# coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch ParlerTTS model.""" import copy import inspect import math import random from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, List import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding from transformers.activations import ACT2FN from transformers.cache_utils import ( Cache, DynamicCache, EncoderDecoderCache, SlidingWindowCache, StaticCache, ) from transformers.generation.configuration_utils import GenerationConfig, GenerationMode from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ModelOutput, Seq2SeqLMOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, is_torchdynamo_compiling, ) from transformers.utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig from .dac_wrapper import DACConfig, DACModel from .logits_processors import ParlerTTSLogitsProcessor from importlib.metadata import version from packaging.version import Version is_dac_integrated_to_transformers = Version(version("transformers")) > Version("4.44.2dev") if not is_dac_integrated_to_transformers: AutoConfig.register("dac", DACConfig) else: AutoConfig.register("dac_on_the_hub", DACConfig) AutoModel.register(DACConfig, DACModel) if TYPE_CHECKING: from transformers.generation.streamers import BaseStreamer logger = logging.get_logger(__name__) if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa else: logger.warn("Flash attention 2 is not installed") _CONFIG_FOR_DOC = "ParlerTTSConfig" _CHECKPOINT_FOR_DOC = "parler-tts/parler-tts-mini-v1" MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [ "parler-tts/parler-tts-mini-v1", # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts ] NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} @dataclass class ParlerTTSSeq2SeqLMOutput(ModelOutput): """ Base class for sequence-to-sequence language models outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None per_codebook_losses: Optional[List[torch.FloatTensor]] = None @dataclass class ParlerTTSCausalLMOutputWithCrossAttentions(ModelOutput): """ Base class for causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Cross attentions weights after the attention softmax, used to compute the weighted average in the cross-attention heads. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. Only relevant if `config.is_decoder = True`. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None per_codebook_losses: Optional[List[torch.FloatTensor]] = None def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where the mask is set to -1, and otherwise setting to the value detailed in the mask.""" seq_len = input_ids.shape[-1] decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) return input_ids def build_delay_pattern_mask( input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int ): """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, seq_len)`: - [B, -1, -1, -1, -1, P, P, P] - [B, B, -1, -1, -1, -1, P, P] - [B, B, B, -1, -1, -1, -1, P] - [B, B, B, B, -1, -1, -1, -1] where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the mask is set to the value in the prompt: - [B, a, b, -1, -1, P, P, P] - [B, B, c, d, -1, -1, P, P] - [B, B, B, e, f, -1, -1, P] - [B, B, B, B, g, h, -1, -1] where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 tokens in our prediction. """ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1]) bsz, num_codebooks, seq_len = input_ids.shape input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 # we only apply the mask if we have a large enough seq len - otherwise we return as is if max_length < 2 * num_codebooks - 1: return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) # fill the shifted ids with the prompt entries, offset by the codebook idx for codebook in range(num_codebooks): # mono channel - loop over the codebooks one-by-one input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] # construct a pattern mask that indicates the positions of padding tokens for each codebook # first fill the upper triangular part (the EOS padding) eos_delay_pattern = torch.triu( torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1 ) # then fill the lower triangular part (the BOS padding) bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool)) bos_mask = ~(bos_delay_pattern).to(input_ids.device) eos_mask = ~(eos_delay_pattern).to(input_ids.device) mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device) input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id # find the first position to start generating - this is the first place we have the -1 token # and will always be in the first codebook (since it has no codebook offset) first_codebook_ids = input_ids[:, 0, :] start_ids = (first_codebook_ids == -1).nonzero()[:, 1] if len(start_ids) > 0: first_start_id = min(start_ids) else: # we have no tokens that need to be filled - return entire matrix of input ids first_start_id = seq_len # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) return input_ids, pattern_mask # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @dataclass class ParlerTTSUnconditionalInput(ModelOutput): """ Args: encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the text encoder model. attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*): Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**. """ encoder_outputs: Tuple[torch.FloatTensor] = None attention_mask: torch.LongTensor = None # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() if decoder_start_token_id is None: raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->ParlerTTS class ParlerTTSSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int): super().__init__() self.embedding_dim = embedding_dim self.make_weights(num_positions, embedding_dim) def make_weights(self, num_embeddings: int, embedding_dim: int): emb_weights = self.get_embedding(num_embeddings, embedding_dim) if hasattr(self, "weights"): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) self.weights = nn.Parameter(emb_weights) self.weights.requires_grad = False self.weights.detach_() @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) return emb.to(torch.get_default_dtype()) @torch.no_grad() def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): bsz, seq_len, _ = input_ids.size() # Create the position ids from the input token ids. position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length # expand embeddings if needed if seq_len > self.weights.size(0): self.make_weights(seq_len + self.offset, self.embedding_dim) return self.weights.index_select(0, position_ids.view(-1)).detach() # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->ParlerTTS class ParlerTTSRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) # Ignore copy @torch.no_grad() def forward(self, device_type, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :] # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos, sin def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: x (`torch.Tensor`): The tensor over which to apply the rope embeddings cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) x_embed = (x * cos) + (rotate_half(x) * sin) return x_embed class ParlerTTSAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper. Modified to use GQA and MQA.""" def __init__( self, embed_dim: int, num_heads: int, num_key_value_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, rope_embeddings: bool = False, layer_idx: Optional[int] = None, config: Optional[ParlerTTSDecoderConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.num_key_value_heads = num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal if layer_idx is None and is_decoder: logger.warning_once( f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.rope_embeddings = rope_embeddings def _shape_query(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len = hidden_states.shape[:2] # get query proj query_states = self.q_proj(hidden_states) * self.scaling query_states = self._shape_query(query_states, tgt_len, bsz) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) if not is_cross_attention: # cached key states already have rope applied - only apply to new state key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_probs, value_states) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights, past_key_value def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenFlashAttention2 with Musicgen->ParlerTTS class ParlerTTSFlashAttention2(ParlerTTSAttention): """ ParlerTTS flash attention module. This module inherits from `ParlerTTSAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # ParlerTTSFlashAttention2 attention does not support output_attentions if isinstance(past_key_value, StaticCache): raise ValueError( "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len = hidden_states.shape[:2] # get query proj query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) if not is_cross_attention and self.rope_embeddings: # cached key states already have rope applied - only apply to new state key_states = apply_rotary_pos_emb(key_states, cos, sin) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] # # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) if query_states.dtype == torch.float32 or value_states.dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, tgt_len, dropout=self.dropout ) attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen class ParlerTTSSdpaAttention(ParlerTTSAttention): def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "ParlerTTSModel is using ParlerTTSSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states, key_value_states=key_value_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len = hidden_states.shape[:2] # get query proj query_states = self.q_proj(hidden_states) query_states = self._shape_query(query_states, tgt_len, bsz) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) if not is_cross_attention and self.rope_embeddings: # cached key states already have rope applied - only apply to new state key_states = apply_rotary_pos_emb(key_states, cos, sin) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value PARLERTTS_ATTENTION_CLASSES = { "eager": ParlerTTSAttention, "sdpa": ParlerTTSSdpaAttention, "flash_attention_2": ParlerTTSFlashAttention2, } class ParlerTTSDecoderLayer(nn.Module): def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.hidden_size self.self_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, dropout=config.attention_dropout, is_decoder=True, is_causal=True, bias=False, rope_embeddings=config.rope_embeddings, layer_idx=layer_idx, config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) cross_attn_implementation = config._attn_implementation if config.cross_attention_implementation_strategy == "always_eager": cross_attn_implementation = "eager" elif config.cross_attention_implementation_strategy == "always_sdpa": cross_attn_implementation = "sdpa" self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[cross_attn_implementation]( self.embed_dim, config.num_attention_heads, num_key_value_heads=config.num_cross_attention_key_value_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, rope_embeddings=config.rope_embeddings, layer_idx=layer_idx, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, cos=cos, sin=sin, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, cos=cos, sin=sin, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # add cross-attn to positions 1 of present_key_value tuple present_key_value = (present_key_value, cross_attn_present_key_value) # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->ParlerTTS class ParlerTTSPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = ParlerTTSDecoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] _supports_cache_class = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, (nn.Linear, nn.Conv1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() MUSICGEN_START_DOCSTRING = r""" The ParlerTTS model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an encoder decoder transformer trained on the task of conditional music generation This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`ParlerTTSConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ MUSICGEN_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `decoder_input_ids`. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. TODO: it's passed through enc_to_dec_proj and optionnally we concat the prompt hidden states in certain cases. past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True` Two formats are allowed: - An [`~cache_utils.EncoderDecoderCache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of `inputs_embeds`. prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache in the correct position and to infer the complete sequence length. """ MUSICGEN_DECODER_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. [What are input IDs?](../glossary#input-ids) The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as `input_ids`. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input embeds. prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): Mask to avoid performing cross-attention on padding tokens indices of prompt input_ids. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing cross-attention on hidden heads. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`] """ def __init__(self, config: ParlerTTSDecoderConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop self.max_target_positions = config.max_position_embeddings self.d_model = config.hidden_size self.num_codebooks = config.num_codebooks self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 # TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now. embed_dim = config.vocab_size + 1 # + 1 for pad token id self.embed_tokens = nn.ModuleList( [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] ) self.rope_embeddings = config.rope_embeddings if not config.rope_embeddings: self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( config.max_position_embeddings, config.hidden_size, ) else: self.rotary_emb = ParlerTTSRotaryEmbedding( config.hidden_size // config.num_attention_heads, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) self.layers = nn.ModuleList( [ParlerTTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation encoder_attn_implementation = config._attn_implementation if config.cross_attention_implementation_strategy is not None: encoder_attn_implementation = ( "sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager" ) self.encoder_attn_implementation = encoder_attn_implementation self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position=None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len) input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) bsz, num_codebooks, seq_len = input.shape input_shape = (bsz, seq_len) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] input = inputs_embeds[:, :, -1:] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) prepended_sequence_length = 0 # if prompt_hidden_states, fuse to inputs_embeds and update input shape if prompt_hidden_states is not None: prepended_sequence_length = prompt_hidden_states.shape[-2] inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) return_legacy_cache = False return_self_attention_cache = False if use_cache or past_key_values is not None: if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) elif not isinstance(past_key_values, EncoderDecoderCache): return_legacy_cache = True logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = 0 if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: past_key_values_length = past_key_values.get_seq_length() if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + input_shape[1] + prepended_sequence_length, device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings # NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask # i.i.f `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration's taking care of setting # `prompt_attention_mask=None` if prompt_attention_mask is not None and attention_mask is not None: attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) elif prompt_attention_mask is not None: logger.warning_once( "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." ) if past_key_values_length == 0: attention_mask = torch.cat( [ prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype), ], dim=1, ) else: # In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch # to be able to prepend the prompt attention mask. # Since we generate token per token, we can recompute the generated length from the information we have. generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 attention_mask = torch.cat( [ prompt_attention_mask, torch.ones( (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype ), ], dim=1, ) input_shape = inputs_embeds.size()[:-1] cos, sin = None, None if not self.rope_embeddings: # embed positions # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings # maybe should modify position embeddings positions = self.embed_positions(inputs_embeds, past_key_values_length) hidden_states = inputs_embeds + positions.to(inputs_embeds.device) else: hidden_states = inputs_embeds if position_ids is None: if attention_mask is not None: # masked ids will **not** count in the position embeddings position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = torch.arange( past_key_values_length, input_shape[1] + past_key_values_length, dtype=torch.long, device=inputs_embeds.device, ) position_ids = position_ids.unsqueeze(0) # Some generation methods already pass only the last input ID if position_ids.shape[1] > input_shape[1]: position_ids = position_ids[:, -input_shape[1] :] cos, sin = self.rotary_emb(hidden_states.device.type, position_ids) cos, sin = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values.self_attention_cache if past_key_values is not None else None, output_attentions, ) if encoder_hidden_states is not None and encoder_attention_mask is not None: if self.encoder_attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.encoder_attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != len(self.layers): raise ValueError( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {attn_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): continue if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, causal_mask, cos, sin, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, output_attentions, use_cache, cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, cos=cos, sin=sin, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), past_key_value=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = past_key_values if use_cache else None if return_self_attention_cache: next_cache = past_key_values.self_attention_cache if return_legacy_cache: next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @add_start_docstrings( "The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.", MUSICGEN_START_DOCSTRING, ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->ParlerTTS class ParlerTTSModel(ParlerTTSPreTrainedModel): def __init__(self, config: ParlerTTSDecoderConfig): super().__init__(config) self.decoder = ParlerTTSDecoder(config) self.config = config # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.decoder.embed_tokens def set_input_embeddings(self, value): self.decoder.embed_tokens = value def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states=encoder_hidden_states, prompt_hidden_states=prompt_hidden_states, prompt_attention_mask=prompt_attention_mask, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) if not return_dict: return decoder_outputs return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, ) @add_start_docstrings( "The Parler-TTS decoder model with a language modelling head on top.", MUSICGEN_START_DOCSTRING, ) class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): def __init__(self, config: ParlerTTSDecoderConfig): super().__init__(config) self.model = ParlerTTSModel(config) self.num_codebooks = config.num_codebooks self.vocab_size = config.vocab_size self.num_codebooks = config.num_codebooks self.use_fused_lm_heads = config.use_fused_lm_heads if self.use_fused_lm_heads: self.lm_heads = nn.Linear(config.hidden_size, config.vocab_size * config.num_codebooks, bias=False) else: self.lm_heads = nn.ModuleList( [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.decoder.embed_tokens def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value def get_output_embeddings(self): return self.lm_heads def set_output_embeddings(self, new_embeddings): self.lm_heads = new_embeddings def set_decoder(self, decoder): self.model.decoder = decoder def get_decoder(self): return self.model.decoder @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ParlerTTSCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, loss_reduction: str = "mean", ) -> Union[Tuple, ParlerTTSCausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` Returns: """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, prompt_hidden_states=prompt_hidden_states, prompt_attention_mask=prompt_attention_mask, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] if self.use_fused_lm_heads: lm_logits = self.lm_heads(hidden_states).view(hidden_states.shape[0], -1, self.num_codebooks, self.vocab_size).transpose(1,2) else: lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) loss = None per_codebook_losses = None if labels is not None: codebook_weights = self.config.codebook_weights # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels logits = lm_logits[:, :, -labels.shape[1] :] loss_fct = CrossEntropyLoss(reduction=loss_reduction) loss = torch.zeros([], device=self.device) per_codebook_losses = [] # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks) labels = labels.masked_fill(labels == self.config.bos_token_id, -100) # we use every codebooks token AND one single EOS at the end of each codebooks mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100)) # per codebook cross-entropy for codebook in range(self.config.num_codebooks): codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) codebook_mask = mask[..., codebook].contiguous().view(-1) codebook_labels = labels[..., codebook].contiguous().view(-1) codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask]) per_codebook_losses.append(codebook_loss) if codebook_weights is not None: codebook_loss = codebook_loss*codebook_weights[codebook] loss += codebook_loss if codebook_weights is not None: loss = loss / sum(codebook_weights) else: loss = loss / self.config.num_codebooks # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output + (per_codebook_losses, )) if loss is not None else output return ParlerTTSCausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, per_codebook_losses=per_codebook_losses, ) def prepare_inputs_for_generation( self, input_ids, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, prompt_hidden_states=None, prompt_attention_mask=None, head_mask=None, cross_attn_head_mask=None, past_key_values=None, use_cache=True, delay_pattern_mask=None, cache_position=None, inputs_embeds=None, **kwargs, ): if delay_pattern_mask is None: input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, bos_token_id=self.generation_config.bos_token_id, pad_token_id=self.generation_config.pad_token_id, max_length=self.generation_config.max_length, ) # apply the delay pattern mask input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values is not None: input_ids = input_ids[:, -1:] if position_ids is not None: position_ids = position_ids[:, -input_ids.shape[1] :] # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None return { "input_ids": input_ids.contiguous(), # `contiguous()` needed for compilation use cases "attention_mask": attention_mask, "position_ids": position_ids, "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "prompt_hidden_states": prompt_hidden_states, "prompt_attention_mask": prompt_attention_mask, "head_mask": head_mask, "cross_attn_head_mask": cross_attn_head_mask, "past_key_values": past_key_values, "use_cache": use_cache, "cache_position": cache_position, "inputs_embeds": inputs_embeds, } # Ignore copy def build_delay_pattern_mask( self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None ): """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, seq_len)`: - [B, -1, -1, -1, -1, P, P, P] - [B, B, -1, -1, -1, -1, P, P] - [B, B, B, -1, -1, -1, -1, P] - [B, B, B, B, -1, -1, -1, -1] where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the mask is set to the value in the prompt: - [B, a, b, -1, -1, P, P, P] - [B, B, c, d, -1, -1, P, P] - [B, B, B, e, f, -1, -1, P] - [B, B, B, B, g, h, -1, -1] where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 tokens in our prediction. """ max_length = max_length if max_length is not None else self.generation_config.max_length return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks) @staticmethod def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where the mask is set to -1, and otherwise setting to the value detailed in the mask.""" return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, synced_gpus: Optional[bool] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, ): """ Generates sequences of token ids for models with a language modeling head. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation strategies and code examples, check out the [following guide](./generation_strategies). Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateDecoderOnlyOutput`], - [`~generation.GenerateBeamDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects if generation_config is None: generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs` input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = input_ids.shape[0] // self.num_codebooks self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: self._prepare_attention_mask_for_generation( input_ids, generation_config.pad_token_id, generation_config.eos_token_id ) # 5. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=input_ids, input_ids_length=input_ids_length, ) # 6. Prepare `input_ids` which will be used for auto-regressive generation # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) if streamer is not None: streamer.put(input_ids.cpu()) # stash the delay mask so that we don't have to recompute it in each forward pass model_kwargs["delay_pattern_mask"] = delay_pattern_mask # 7. determine generation mode is_greedy_gen_mode = ( (generation_config.num_beams == 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is False ) is_sample_gen_mode = ( (generation_config.num_beams == 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is True ) # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, device=input_ids.device, ) # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing greedy search, " f"but is {generation_config.num_return_sequences}." ) # 10. run greedy search outputs = self._sample( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, **model_kwargs, ) # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) else: raise ValueError( "Got incompatible mode for generation, should be one of greedy or sampling. " "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." ) if generation_config.return_dict_in_generate: output_ids = outputs.sequences else: output_ids = outputs # apply the pattern mask to the final ids output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask _, mask = self.build_delay_pattern_mask( input_ids, bos_token_id=generation_config.bos_token_id, pad_token_id=generation_config.pad_token_id, max_length=output_ids.shape[1], ) mask = (mask != generation_config._bos_token_tensor) & (mask != generation_config._pad_token_tensor) output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1) if generation_config.return_dict_in_generate: outputs.sequences = output_ids return outputs else: return output_ids @add_start_docstrings( "The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, " "for music generation tasks with one or both of text and audio prompts.", MUSICGEN_START_DOCSTRING, ) class ParlerTTSForConditionalGeneration(PreTrainedModel): config_class = ParlerTTSConfig base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True def __init__( self, config: Optional[ParlerTTSConfig] = None, text_encoder: Optional[PreTrainedModel] = None, audio_encoder: Optional[PreTrainedModel] = None, decoder: Optional[ParlerTTSForCausalLM] = None, ): if config is None and (text_encoder is None or audio_encoder is None or decoder is None): raise ValueError( "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder." ) if config is None: config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config) else: if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") if config.decoder.cross_attention_hidden_size is not None: if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size: raise ValueError( "If `cross_attention_hidden_size` is specified in the Parler-TTS decoder's configuration, it has to be equal" f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for" " `config.text_encoder.hidden_size`." ) # initialize with config super().__init__(config) if text_encoder is None: from transformers.models.auto.modeling_auto import AutoModelForTextEncoding text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) if audio_encoder is None: from transformers.models.auto.modeling_auto import AutoModel audio_encoder = AutoModel.from_config(config.audio_encoder) if decoder is None: decoder = ParlerTTSForCausalLM._from_config(config.decoder) self.text_encoder = text_encoder self.audio_encoder = audio_encoder self.decoder = decoder if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict(): logger.warning( f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:" f" {self.config.text_encoder}" ) if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): logger.warning( f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" f" {self.config.audio_encoder}" ) if self.decoder.config.to_dict() != self.config.decoder.to_dict(): logger.warning( f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" f" {self.config.decoder}" ) # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.text_encoder.config = self.config.text_encoder self.audio_encoder.config = self.config.audio_encoder self.decoder.config = self.config.decoder # text encoder outputs might need to be projected to different dimension for decoder if ( self.text_encoder.config.hidden_size != self.decoder.config.hidden_size and self.decoder.config.cross_attention_hidden_size is None ): self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) # prompt embeddings self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size) self.prompt_cross_attention = config.prompt_cross_attention if config.prompt_cross_attention: self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( config.decoder.max_position_embeddings, config.decoder.hidden_size, ) if self.text_encoder.get_output_embeddings() is not None: raise ValueError( f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head" ) decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) if "encoder_hidden_states" not in decoder_signature: raise ValueError( "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" ) audio_encoder_signature = set(inspect.signature(self.audio_encoder.decode).parameters.keys()) self.use_audio_scales = "audio_scales" in audio_encoder_signature self.use_4dim_audio_codes = False audio_type = audio_encoder.config.model_type if audio_type in {"encodec", "dac_on_the_hub"} or (audio_type == "dac" and not is_dac_integrated_to_transformers): self.use_4dim_audio_codes = True # Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly self.post_init() def _init_weights(self, module): std = self.decoder.config.initializer_factor if isinstance(module, (nn.Linear, nn.Conv1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def tie_weights(self): # tie text encoder & decoder if needed if self.config.tie_encoder_decoder: # tie text encoder and decoder base model decoder_base_model_prefix = self.decoder.base_model_prefix self._tie_encoder_decoder_weights( self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) def get_audio_encoder(self): return self.audio_encoder def get_text_encoder(self): return self.text_encoder def get_encoder(self): # get the text encoder to compute the encoder hidden-states for generation return self.get_text_encoder() def get_decoder(self): return self.decoder def get_input_embeddings(self): return self.text_encoder.get_input_embeddings() def get_output_embeddings(self): return self.decoder.get_output_embeddings() def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): r""" Example: ```python >>> from parler_tts import ParlerTTSForConditionalGeneration >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1") ```""" # At the moment fast initialization is not supported for composite models if kwargs.get("_fast_init", False): logger.warning( "Fast initialization is currently not supported for ParlerTTSForConditionalGeneration. " "Falling back to slow initialization..." ) kwargs["_fast_init"] = False return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod def from_sub_models_pretrained( cls, text_encoder_pretrained_model_name_or_path: str = None, audio_encoder_pretrained_model_name_or_path: str = None, decoder_pretrained_model_name_or_path: str = None, *model_args, **kwargs, ) -> PreTrainedModel: r""" Instantiate a text encoder, an audio encoder, and a Parler-TTS decoder from one, two or three base classes of the library from pretrained model checkpoints. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train the model, you need to first set it back in training mode with `model.train()`. Params: text_encoder_pretrained_model_name_or_path (`str`, *optional*): Information necessary to initiate the text encoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `t5-base`, or namespaced under a user or organization name, like `google/flan-t5-base. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. audio_encoder_pretrained_model_name_or_path (`str`, *optional*): Information necessary to initiate the audio encoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `facebook/encodec_24khz`. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): Information necessary to initiate the decoder. Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or organization name, like `parler-tts/parler-tts-mini-v1`. - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. model_args (remaining positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration parameter. - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration parameter. - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - To update the parent model configuration, do not use a prefix for each configuration parameter. Behaves differently depending on whether a `config` is provided or automatically loaded. Example: ```python >>> from parler_tts import ParlerTTSForConditionalGeneration >>> # initialize a parler_tts model from a t5 text encoder, encodec audio encoder, and parler_tts decoder >>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained( ... text_encoder_pretrained_model_name_or_path="t5-base", ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", ... decoder_pretrained_model_name_or_path="parler-tts/parler-tts-mini-v1", ... ) >>> # saving model after fine-tuning >>> model.save_pretrained("./parler_tts-ft") >>> # load fine-tuned model >>> model = ParlerTTSForConditionalGeneration.from_pretrained("./parler_tts-ft") ```""" kwargs_text_encoder = { argument[len("text_encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("text_encoder_") } kwargs_audio_encoder = { argument[len("audio_encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("audio_encoder_") } kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } # remove text encoder, audio encoder and decoder kwargs from kwargs for key in kwargs_text_encoder.keys(): del kwargs["text_encoder_" + key] for key in kwargs_audio_encoder.keys(): del kwargs["audio_encoder_" + key] for key in kwargs_decoder.keys(): del kwargs["decoder_" + key] # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made # by the value of the flag `is_decoder` that we need to set correctly. text_encoder = kwargs_text_encoder.pop("model", None) if text_encoder is None: if text_encoder_pretrained_model_name_or_path is None: raise ValueError( "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_text_encoder: encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False kwargs_text_encoder["config"] = encoder_config text_encoder = AutoModelForTextEncoding.from_pretrained( text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder ) audio_encoder = kwargs_audio_encoder.pop("model", None) if audio_encoder is None: if audio_encoder_pretrained_model_name_or_path is None: raise ValueError( "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_audio_encoder: encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False kwargs_audio_encoder["config"] = encoder_config audio_encoder = AutoModel.from_pretrained( audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder ) decoder = kwargs_decoder.pop("model", None) if decoder is None: if decoder_pretrained_model_name_or_path is None: raise ValueError( "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " "to be defined." ) if "config" not in kwargs_decoder: decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained( decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True ) if isinstance(decoder_config, ParlerTTSConfig): decoder_config = decoder_config.decoder if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True kwargs_decoder["config"] = decoder_config if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " "`decoder_config` to `.from_sub_models_pretrained(...)`" ) decoder = ParlerTTSForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) # instantiate config with corresponding kwargs config = ParlerTTSConfig.from_sub_models_config( text_encoder.config, audio_encoder.config, decoder.config, **kwargs ) return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ParlerTTSSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, input_values: Optional[torch.FloatTensor] = None, padding_mask: Optional[torch.BoolTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, prompt_input_ids: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, decoder_position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, loss_reduction: str = "mean", **kwargs, ) -> Union[Tuple, ParlerTTSSeq2SeqLMOutput]: r""" Returns: Examples: ```python >>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration >>> import torch >>> processor = AutoProcessor.from_pretrained("parler-tts/parler-tts-mini-v1") >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1") >>> inputs = processor( ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], ... padding=True, ... return_tensors="pt", ... ) >>> pad_token_id = model.generation_config.pad_token_id >>> decoder_input_ids = ( ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) ... * pad_token_id ... ) >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) torch.Size([8, 1, 2048]) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict kwargs_text_encoder = { argument[len("text_encoder_")]: value for argument, value in kwargs.items() if argument.startswith("text_encoder_") } kwargs_audio_encoder = { argument[len("audio_encoder_")]: value for argument, value in kwargs.items() if argument.startswith("audio_encoder_") } kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } if prompt_hidden_states is None: if prompt_input_ids is not None: prompt_hidden_states = self.embed_prompts(prompt_input_ids) if encoder_outputs is None: encoder_outputs = self.text_encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs_text_encoder, ) encoder_hidden_states = encoder_outputs[0] # optionally project encoder_hidden_states if ( self.text_encoder.config.hidden_size != self.decoder.config.hidden_size and self.decoder.config.cross_attention_hidden_size is None ): encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) if attention_mask is not None: encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] if prompt_hidden_states is not None and self.prompt_cross_attention: # add sinusoidal positional embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) if prompt_attention_mask is not None and attention_mask is None: attention_mask = torch.ones( encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype ) elif attention_mask is not None and prompt_attention_mask is None: prompt_attention_mask = torch.ones( prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype ) # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) if prompt_attention_mask is not None: attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) prompt_hidden_states = None prompt_attention_mask = None encoder_outputs["last_hidden_state"] = encoder_hidden_states elif isinstance(encoder_outputs, tuple): encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs.last_hidden_state if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ).transpose(1, 2) elif decoder_input_ids is None and decoder_inputs_embeds is None: audio_encoder_outputs = self.audio_encoder( input_values=input_values, padding_mask=padding_mask, **kwargs_audio_encoder, ) audio_codes = audio_encoder_outputs.audio_codes frames, bsz, codebooks, seq_len = audio_codes.shape if frames != 1: raise ValueError( f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " "disabled by setting `chunk_length=None` in the audio encoder." ) if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2: # mono input through encodec that we convert to stereo audio_codes = audio_codes.repeat_interleave(2, dim=2) decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, position_ids=decoder_position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=attention_mask, prompt_hidden_states=prompt_hidden_states, prompt_attention_mask=prompt_attention_mask, inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, past_key_values=past_key_values, return_dict=return_dict, labels=labels, cache_position=cache_position, loss_reduction=loss_reduction, **kwargs_decoder, ) if not return_dict: return decoder_outputs + (encoder_hidden_states,) return ParlerTTSSeq2SeqLMOutput( loss=decoder_outputs.loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, per_codebook_losses=decoder_outputs.per_codebook_losses, ) def prepare_inputs_for_generation( self, decoder_input_ids, past_key_values=None, attention_mask=None, head_mask=None, decoder_attention_mask=None, decoder_head_mask=None, prompt_hidden_states=None, prompt_attention_mask=None, cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, decoder_delay_pattern_mask=None, cache_position=None, inputs_embeds=None, **kwargs, ): if decoder_delay_pattern_mask is None: decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( decoder_input_ids, bos_token_id=self.generation_config.bos_token_id, pad_token_id=self.generation_config.pad_token_id, max_length=self.generation_config.max_length, ) # apply the delay pattern mask decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() if past_key_values.get_seq_length() > 0: # we only want to use prompt signal in the 1st generation step prompt_hidden_states = None else: past_length = past_key_values[0][0].shape[2] # we only want to use prompt signal in the 1st generation step prompt_hidden_states = None # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] if cache_position is None: cache_position = torch.arange( past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device ) elif use_cache: cur_len = decoder_input_ids.shape[1] if prompt_hidden_states is not None and not self.prompt_cross_attention: # meaning we are in 1st generation step and prompt_hidden_state will be prepended cur_len += prompt_hidden_states.shape[1] cache_position = cache_position[-cur_len:] if decoder_attention_mask is None and prompt_attention_mask is not None: input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1]) bsz, _, seq_len = input.shape input_shape = (bsz, seq_len) past_key_values_length = 0 if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: past_key_values_length = past_key_values.get_seq_length() logger.warning_once( "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." ) if past_key_values is None or ( isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0 ): decoder_attention_mask = torch.ones(input_shape, device=self.device, dtype=decoder_input_ids.dtype) elif prompt_attention_mask is not None: # In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch # to be able to prepend the prompt attention mask. # Since we generate token per token, we can recompute the generated length from the information we have. generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 decoder_attention_mask = torch.ones( (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype ) return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids.contiguous(), "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "prompt_hidden_states": prompt_hidden_states, "prompt_attention_mask": prompt_attention_mask, "use_cache": use_cache, "cache_position": cache_position, "inputs_embeds": inputs_embeds, } def _prepare_decoder_input_ids_for_generation( self, batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], decoder_start_token_id: int = None, bos_token_id: int = None, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. if model_kwargs is not None and "decoder_input_ids" in model_kwargs: decoder_input_ids = model_kwargs.pop("decoder_input_ids") elif "input_ids" in model_kwargs and model_input_name != "input_ids": decoder_input_ids = model_kwargs.pop("input_ids") else: decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device decoder_input_ids_start = ( torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) * decoder_start_token_id ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: decoder_input_ids = decoder_input_ids_start # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask if not self.prompt_cross_attention: prompt_hidden_states = model_kwargs["prompt_hidden_states"] num_codebooks = self.decoder.num_codebooks input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1]) inputs_embeds = sum( [ self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks) ] ) inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) model_kwargs["inputs_embeds"] = inputs_embeds return decoder_input_ids, model_kwargs def _prepare_text_encoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str], generation_config: GenerationConfig, ) -> Dict[str, Any]: # 1. get text encoder encoder = self.get_text_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): encoder._hf_hook.io_same_device = True # 2. Prepare encoder args and encoder kwargs from model kwargs. irrelevant_prefix = ["decoder_", "cross_attn", "prompt_", "use_cache", "labels"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature if not encoder_accepts_wildcard: encoder_kwargs = { argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } encoder_kwargs["output_attentions"] = generation_config.output_attentions encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor last_hidden_state = encoder(**encoder_kwargs).last_hidden_state # we optionnally project last_hidden_state to avoid recomputing every time encoder_hidden_states = last_hidden_state if ( self.text_encoder.config.hidden_size != self.decoder.config.hidden_size and self.decoder.config.cross_attention_hidden_size is None ): encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) if model_kwargs["attention_mask"] is not None: encoder_hidden_states = encoder_hidden_states * model_kwargs["attention_mask"][..., None] model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=encoder_hidden_states) return model_kwargs def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs): prompt_hidden_states = self.embed_prompts(prompt_input_ids) if self.prompt_cross_attention: # add sinusoidal positional embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) attention_mask = model_kwargs.get("attention_mask", None) prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None) encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state if prompt_attention_mask is not None and attention_mask is None: attention_mask = torch.ones( encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype ) elif attention_mask is not None and prompt_attention_mask is None: prompt_attention_mask = torch.ones( prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype ) # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) if prompt_attention_mask is not None: attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states model_kwargs["attention_mask"] = attention_mask # in this case, since we already concatenated the prompt hidden states and attention mask, we don't need them anymore. model_kwargs["prompt_hidden_states"] = None model_kwargs["prompt_attention_mask"] = None else: model_kwargs["prompt_hidden_states"] = prompt_hidden_states # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly return model_kwargs def _prepare_audio_encoder_kwargs_for_generation( self, input_values, model_kwargs, model_input_name: Optional[str] = None ): # 1. get audio encoder encoder = self.get_audio_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): encoder._hf_hook.io_same_device = True # 2. Prepare encoder args and encoder kwargs from model kwargs. irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature if not encoder_accepts_wildcard: encoder_kwargs = { argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name encoder_kwargs["return_dict"] = True if "num_quantizers" in encoder_signature: encoder_kwargs["num_quantizers"] = self.config.decoder.num_codebooks elif "num_codebooks" in encoder_signature: encoder_kwargs["num_codebooks"] = self.config.decoder.num_codebooks elif "n_quantizers" in encoder_signature: encoder_kwargs["n_quantizers"] = self.config.decoder.num_codebooks encoder_kwargs[model_input_name] = input_values audio_encoder_outputs = encoder.encode(**encoder_kwargs) audio_codes = audio_encoder_outputs.audio_codes audio_scales = audio_encoder_outputs.get("audio_scales") if audio_codes.ndim == 3: bsz, codebooks, seq_len = audio_codes.shape decoder_input_ids = audio_codes.reshape(bsz * self.decoder.num_codebooks, seq_len) else: frames, bsz, codebooks, seq_len = audio_codes.shape if frames != 1: raise ValueError( f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " "disabled by setting `chunk_length=None` in the audio encoder." ) decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) model_kwargs["decoder_input_ids"] = decoder_input_ids if audio_scales is not None: model_kwargs["audio_scales"] = audio_scales return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right( labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id ).transpose(1, 2) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" " model.decoder.resize_token_embeddings(...))" ) def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, bos_token_id: Optional[int] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" if inputs is not None: return inputs encoder_outputs = model_kwargs.get("encoder_outputs") if encoder_outputs is not None: # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding shape = encoder_outputs[0].size()[:-1] return torch.ones(shape, dtype=torch.long, device=self.device) * -100 if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with # soft-prompting or in multimodal implementations built on top of decoder-only language models. batch_size = 1 for value in model_kwargs.values(): if isinstance(value, torch.Tensor): batch_size = value.shape[0] break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def _get_decoder_start_token_id( self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None ) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id ) bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id elif bos_token_id is not None: return bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a new `generate` call requires a larger cache. Returns the resulting cache object. """ cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) if hasattr(self, "_cache"): cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache if cache_implementation == "sliding_window": max_cache_len = min(self.config.sliding_window, max_cache_len) need_new_cache = ( not hasattr(self, "_cache") or (not isinstance(cache_to_check, cache_cls)) or cache_to_check.max_batch_size != max_batch_size or cache_to_check.max_cache_len < max_cache_len ) if requires_cross_attention_cache and hasattr(self, "_cache"): need_new_cache = ( need_new_cache or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] ) if need_new_cache: if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: cache_dtype = self.dtype cache_kwargs = { "config": self.config.decoder, "max_batch_size": max_batch_size, "max_cache_len": max_cache_len, "device": self.device, "dtype": cache_dtype, } self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] config_cross_attention_cache = copy.deepcopy(self.config.decoder) config_cross_attention_cache.update( {"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads} ) encoder_kwargs["config"] = config_cross_attention_cache self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) else: self._cache.reset() return self._cache def freeze_encoders(self, freeze_text_encoder=True): if freeze_text_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False self.text_encoder._requires_grad = False for param in self.audio_encoder.parameters(): param.requires_grad = False self.audio_encoder._requires_grad = False @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, synced_gpus: Optional[bool] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, ): """ Generates sequences of token ids for models with a language modeling head. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation strategies and code examples, check out the [following guide](./generation_strategies). Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateDecoderOnlyOutput`], - [`~generation.GenerateBeamDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GenerateEncoderDecoderOutput`], - [`~generation.GenerateBeamEncoderDecoderOutput`] """ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects if generation_config is None: generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple: # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) # 2. Set generation parameters if not already defined requires_attention_mask = "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList([ParlerTTSLogitsProcessor(generation_config.eos_token_id, self.decoder.num_codebooks, batch_size, inputs_tensor.device)]) stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) if "encoder_outputs" not in model_kwargs: # encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_text_encoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name, generation_config ) if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs: # `prompt_hidden_states` are created and added to `model_kwargs` model_kwargs = self._prepare_prompt_kwargs_for_generation( model_kwargs["prompt_input_ids"], model_kwargs, ) if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( model_kwargs["input_values"], model_kwargs, ) # 5. Prepare `input_ids` which will be used for auto-regressive generation input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, decoder_start_token_id=generation_config._decoder_start_token_tensor, bos_token_id=generation_config._bos_token_tensor, device=inputs_tensor.device, ) # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: raise ValueError( "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " "Cache object) is unsupported. Please use only one of the two." ) elif generation_config.cache_implementation is not None: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static" and not self._supports_static_cache: raise ValueError( "This model does not support `cache_implementation='static'`. Please check the following " "issue: https://github.com/huggingface/transformers/issues/28981" ) if not self.prompt_cross_attention: # when we prepend prompt_hidden_state to inputs_embeds, max_cache_len needs to be actualised # generation_config.max_length has already been increased by input_ids_length which is # already counted in input_embeds_seq_length so we remove it input_embeds_seq_length = model_kwargs["inputs_embeds"].shape[1] max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_length else: max_cache_len = self.generation_config.max_length model_kwargs["past_key_values"] = self._get_cache( generation_config.cache_implementation, getattr(generation_config, "num_beams", 1) * batch_size, max_cache_len, model_kwargs, ) elif generation_config.cache_implementation == "quantized": raise ValueError( "This model does not support the quantized cache. If you want your model to support quantized " "cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts" ) # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): past = model_kwargs.get("past_key_values", None) requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) if past is None: model_kwargs["past_key_values"] = ( DynamicCache() if not requires_cross_attention_cache else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) elif isinstance(past, tuple): model_kwargs["past_key_values"] = ( DynamicCache.from_legacy_cache(past) if not requires_cross_attention_cache else EncoderDecoderCache.from_legacy_cache(past) ) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) delayed_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, bos_token_id=generation_config._bos_token_tensor, pad_token_id=generation_config._pad_token_tensor, max_length=generation_config.max_length, ) # stash the delay mask so that we don't have to recompute in each forward pass model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask # input_ids are ready to be placed on the streamer (if used) if streamer is not None: streamer.put(delayed_input_ids.cpu()) # 7. determine generation mode generation_mode = generation_config.get_generation_mode() # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, device=delayed_input_ids.device, ) # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): # expand input_ids with `num_return_sequences` additional sequences per batch delayed_input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=delayed_input_ids, expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 10. run sample outputs = self._sample( delayed_input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) else: raise ValueError( "Got incompatible mode for generation, should be one of greedy or sampling. " "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." ) if generation_config.return_dict_in_generate: output_ids = outputs.sequences else: output_ids = outputs # Apply the pattern mask to the final ids output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask _, mask = self.decoder.build_delay_pattern_mask( input_ids, bos_token_id=generation_config.bos_token_id, pad_token_id=generation_config.pad_token_id, max_length=output_ids.shape[1], ) mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id) output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1) # append the frame dimension back to the audio codes output_ids = output_ids[None, ...] audio_decode_kwargs = {} if self.use_audio_scales: audio_scales = model_kwargs.get("audio_scales") if audio_scales is None: audio_scales = [None] * batch_size audio_decode_kwargs["audio_scales"] = audio_scales if not self.use_4dim_audio_codes: # remove chunk dim output_ids = output_ids.squeeze(0) decode_sequentially = ( generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids ) if not decode_sequentially: output_values = self.audio_encoder.decode( audio_codes=output_ids, **audio_decode_kwargs, ).audio_values.squeeze(1) output_lengths = [audio.shape[0] for audio in output_values] else: output_values = [] for sample_id in range(batch_size): sample = output_ids[:, sample_id] if self.use_4dim_audio_codes else output_ids[sample_id] sample_mask = (sample >= self.audio_encoder.config.codebook_size) sample_mask = (sample_mask.sum(dim=(0, 1)) == 0) if self.use_4dim_audio_codes else (sample_mask.sum(dim=0) == 0) single_audio_decode_kwargs = {} if self.use_audio_scales: single_audio_decode_kwargs["audio_scales"] = [audio_decode_kwargs["audio_scales"][sample_id]] if sample_mask.sum() > 0: sample = sample[:, :, sample_mask] if self.use_4dim_audio_codes else sample[:, sample_mask] sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **single_audio_decode_kwargs).audio_values sample = sample if sample.ndim == 3 else sample.unsqueeze(0) output_values.append(sample.transpose(0, 2)) else: output_values.append(torch.zeros((1, 1, 1)).to(self.device)) output_lengths = [audio.shape[0] for audio in output_values] output_values = ( torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0) .squeeze(-1) .squeeze(-1) ) if generation_config.return_dict_in_generate: outputs["audios_length"] = output_lengths outputs.sequences = output_values return outputs else: return output_values def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` if "inputs_embeds" in model_kwargs: cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 past_length = 0 if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] past_length = 0 if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. if not is_torchdynamo_compiling(): cache_position = cache_position[past_length:] model_kwargs["cache_position"] = cache_position return model_kwargs