diff --git "a/modeling_kosmos2_5.py" "b/modeling_kosmos2_5.py" --- "a/modeling_kosmos2_5.py" +++ "b/modeling_kosmos2_5.py" @@ -1,2622 +1,3 @@ -# coding=utf-8 -# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch KOSMOS-2.5 model.""" - -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPooling, - CausalLMOutputWithCrossAttentions, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_kosmos2_5 import ( - Kosmos2_5Config, - Kosmos2_5TextConfig, - Kosmos2_5VisionConfig, -) - - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = Kosmos2_5Config - - -# Copied from ...models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -100.0) - - -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), -100.0, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from ...models.roberta.modeling_roberta.create_position_ids_from_input_ids -def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ - # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask - return incremental_indices.long() + padding_idx - - -KOSMOS2_5_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Kosmos2_5Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -KOSMOS2_5_VISION_INPUTS_DOCSTRING = r""" - Args: - flattened_patches (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`Kosmos2_5ImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -KOSMOS2_5_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. 
- - [What are attention masks?](../glossary#attention-mask) - image_embeds: (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): - Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. - image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, - 1]`: - - - 1 for places where to put the image features, - - 0 for places that are not for image features (i.e. for text tokens). - - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -KOSMOS2_5_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See - [`Kosmos2_5ImageProcessor.__call__`] for details. - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, - 1]`: - - - 1 for places where to put the image features, - - 0 for places that are not for image features (i.e. for text tokens). - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - image_embeds: (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): - Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. 
- output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@dataclass -class Kosmos2_5ModelOutput(ModelOutput): - """ - Base class for text model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): - Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. - projection_attentions (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute - the weighted average in the self-attention heads. - vision_model_output(`BaseModelOutputWithPooling`, *optional*): - The output of the [`Kosmos2VisionModel`]. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. 
- """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_embeds: Optional[torch.FloatTensor] = None - projection_attentions: Optional[Tuple[torch.FloatTensor]] = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> Tuple[Any]: - return tuple( - (self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()) - for k in self.keys() - ) - - -@dataclass -class Kosmos2_5ForConditionalGenerationModelOutput(ModelOutput): - """ - Model output class for `Kosmos2ForConditionalGeneration`. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): - Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. - projection_attentions (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute - the weighted average in the self-attention heads. - vision_model_output(`BaseModelOutputWithPooling`, *optional*): - The output of the [`Kosmos2VisionModel`]. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. 
- """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_embeds: Optional[torch.FloatTensor] = None - projection_attentions: Optional[Tuple[torch.FloatTensor]] = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> Tuple[Any]: - return tuple( - (self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()) - for k in self.keys() - ) - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructLayerNorm -> Kosmos2_5LayerNorm -class Kosmos2_5LayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -try: - from apex.normalization import FusedRMSNorm - - Kosmos2_5LayerNorm = FusedRMSNorm # noqa - - logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Kosmos2_5LayerNorm") -except ImportError: - # using the normal Kosmos2_5LayerNorm - pass -except Exception: - logger.warning("Discovered apex but it failed to load, falling back to Kosmos2_5LayerNorm") - pass - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructVisionEmbeddings -> Kosmos2_5VisionEmbeddings -class Kosmos2_5VisionEmbeddings(nn.Module): - def __init__(self, config: Kosmos2_5VisionConfig) -> None: - super().__init__() - self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size) - - self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size) - self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size) - - self.dropout = nn.Dropout(config.dropout_rate, inplace=False) - - def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor: - # the row and column indices are stored in the first and second position of the flattened_patches - # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2 - row_indices = flattened_patches[:, :, 0].long() - col_indices = flattened_patches[:, :, 1].long() - - flattened_patches = flattened_patches[:, :, 2:] - - embeddings = self.patch_projection(flattened_patches) - row_embeddings = self.row_embedder(row_indices) - col_embeddings = self.column_embedder(col_indices) - - # sum all embeddings together - embeddings = embeddings + row_embeddings + col_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructVisionMlp -> Kosmos2_5VisionMlp -class Kosmos2_5VisionMlp(nn.Module): - def __init__(self, config: 
Kosmos2_5VisionConfig): - super().__init__() - self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) - self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ACT2FN[config.dense_act_fn] - - def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. - # See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - if ( - isinstance(self.wo.weight, torch.Tensor) - and hidden_states.dtype != self.wo.weight.dtype - and self.wo.weight.dtype != torch.int8 - ): - hidden_states = hidden_states.to(self.wo.weight.dtype) - - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructVisionAttention -> Kosmos2_5VisionAttention -class Kosmos2_5VisionAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_attention_heads - self.dropout = config.attention_dropout - self.inner_dim = self.n_heads * self.key_value_proj_dim - self.is_causal = False - - # Mesh TensorFlow initialization to avoid scaling before softmax - self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - output_attentions=False, - ): - """ - Self-attention block - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - batch_size, seq_length, _ = hidden_states.size() - - query_states = self.query(hidden_states) - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - # get query states - # (batch_size, n_heads, seq_length, dim_per_head) - query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.key_value_proj_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, seq_length, 
-1) - attn_output = self.output(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - -class Kosmos2_5VisionFlashAttention2(Kosmos2_5VisionAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - output_attentions=False, - ): - """ - Flash attn Self-attention block - """ - output_attentions = False - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - batch_size, seq_length, _ = hidden_states.size() - # (batch_size, seq_length, inner_dim) - query_states = self.query(hidden_states) - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - # (batch_size, seq_length, self.n_heads , self.key_value_proj_dim) - query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) - key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) - value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - seq_length, - dropout=self.dropout, - ) - attn_output = attn_output.view(batch_size, -1, self.inner_dim) - attn_output = self.output(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class Kosmos2_5VisionSdpaAttention(Kosmos2_5VisionAttention): - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - output_attentions=False, - ): - if output_attentions: - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - batch_size, seq_length, _ = hidden_states.size() - - query_states = self.query(hidden_states) - key_states = self.key(hidden_states) - value_states = self.value(hidden_states) - - query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - causal_mask = attention_mask - if attention_mask is not None: - # Slice the causal_mask to match key_states' last dimension - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and seq_length > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_length, -1) - - attn_output = self.output(attn_output) - - return attn_output, None - - -KOSMOS2_5_VISION_ATTENTION_CLASSES = { - "eager": Kosmos2_5VisionAttention, - "flash_attention_2": Kosmos2_5VisionFlashAttention2, - "sdpa": Kosmos2_5VisionSdpaAttention, -} - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructVisionLayer -> Kosmos2_5VisionLayer -class Kosmos2_5VisionLayer(nn.Module): - def __init__(self, config: Kosmos2_5VisionConfig) -> None: - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.config = config - self.attention = KOSMOS2_5_VISION_ATTENTION_CLASSES[config._attn_implementation](config) - self.mlp = Kosmos2_5VisionMlp(config) - self.pre_mlp_layer_norm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pre_attention_layer_norm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def _prepare_attention_mask(self, attention_mask, input_shape, inputs_embeds): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - return expanded_attn_mask - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - residual = hidden_states - - # in Kosmos2_5Vision, layernorm is applied before self-attention - hidden_states = self.pre_attention_layer_norm(hidden_states) - attention_mask = self._prepare_attention_mask(attention_mask, hidden_states.shape[:2], hidden_states) - self_attention_outputs, _ = self.attention( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=head_mask, - output_attentions=output_attentions, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - # first residual connection - hidden_states = attention_output + residual - - # in Kosmos2_5Vision, layernorm is also applied after self-attention - layer_output = self.pre_mlp_layer_norm(hidden_states) - layer_output = self.mlp(layer_output) + hidden_states # second residual connection - return layer_output, outputs - - -# Copied from ...models.pix2struct.modeling_pix2struct.Pix2StructVisionEncoder -> Kosmos2_5VisionEncoder -class Kosmos2_5VisionEncoder(nn.Module): - def __init__(self, config: Kosmos2_5VisionConfig) -> None: - super().__init__() - self.config = config - self.layer = nn.ModuleList([Kosmos2_5VisionLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - 
output_hidden_states: bool = False, - return_dict: bool = True, - ) -> Union[tuple, BaseModelOutput]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -# Pix2StructVisionModel -> Kosmos2_5VisionModel -class Kosmos2_5VisionModel(PreTrainedModel): - _supports_flash_attn_2 = True - _supports_sdpa = True - - def __init__(self, config: Kosmos2_5VisionConfig): - super().__init__(config) - self.config = config - self.embeddings = Kosmos2_5VisionEmbeddings(config) - self.encoder = Kosmos2_5VisionEncoder(config) - self.layernorm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def get_input_embeddings(self): - return self.embeddings.patch_projection - - def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def forward( - self, - flattened_patches: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if flattened_patches is None: - raise ValueError("You have to specify flattened_patches") - - if attention_mask is None: - # check where `flattened_patches` is not 0 - attention_mask = (flattened_patches.sum(dim=-1) != 0).float() - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - embedding_output = self.embeddings(flattened_patches) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - - if not return_dict: - head_outputs = (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return BaseModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -# Copied from ...models.kosmos2.modeling_kosmos2.Kosmos2TextSinusoidalPositionalEmbedding -> Kosmos2_5TextSinusoidalPositionalEmbedding -class Kosmos2_5TextSinusoidalPositionalEmbedding(nn.Module): - """This module produces sinusoidal positional embeddings of any length.""" - - # Copied from ...models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ - def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): - super().__init__() - self.offset = 2 - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) - - # Copied from ...models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights - def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): - emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) - if hasattr(self, "weights"): - # in forward put the weights on the correct dtype and device of the param - emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) - - self.register_buffer("weights", emb_weights, persistent=False) - - @staticmethod - # Copied from ...models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding - def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: 
Optional[int] = None): - """ - Build sinusoidal embeddings. - - This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of - "Attention Is All You Need". - """ - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) - if embedding_dim % 2 == 1: - # zero pad - emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) - if padding_idx is not None: - emb[padding_idx, :] = 0 - - return emb.to(torch.get_default_dtype()) - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor = None, - inputs_embeds: torch.Tensor = None, - past_key_values_length: int = 0, - position_ids: torch.Tensor = None, - ): - if input_ids is not None: - bsz, seq_len = input_ids.size() - if position_ids is None: - # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = create_position_ids_from_input_ids( - input_ids, self.padding_idx, past_key_values_length - ).to(input_ids.device) - else: - bsz, seq_len = inputs_embeds.size()[:-1] - if position_ids is None: - position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) - - # expand embeddings if needed - max_pos = self.padding_idx + 1 + seq_len + past_key_values_length - if max_pos > self.weights.size(0): - self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) - - return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() - - # Copied from ...models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds - def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): - """ - We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
- - Args: - inputs_embeds: torch.Tensor - - Returns: torch.Tensor - """ - input_shape = inputs_embeds.size()[:-1] - sequence_length = input_shape[1] - - position_ids = torch.arange( - self.padding_idx + 1, - sequence_length + self.padding_idx + 1, - dtype=torch.long, - device=inputs_embeds.device, - ) - return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length - - -class Kosmos2_5TextFFN(nn.Module): - def __init__(self, config: Kosmos2_5TextConfig): - super().__init__() - - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - - self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim) - self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim) - - self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.ffn_layernorm(hidden_states) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - return hidden_states - - -class Kosmos2_5TextAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # Similar to ...models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`. - def __init__( - self, - config, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - add_inner_attn_layernorm: bool = False, - bias: bool = True, - is_causal=True, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." 
- ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.is_causal = is_causal - - # End opy - self.inner_attn_ln = None - if add_inner_attn_layernorm: - self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - - def _shape(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = encoder_hidden_states is not None - batch_size, seq_length = hidden_states.shape[:2] - - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - key_states = self._shape(self.k_proj(current_states)) - value_states = self._shape(self.v_proj(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q_proj(hidden_states) * self.scaling) - # this weight maybe overflow with float16 - attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != ( - batch_size, - self.num_heads, - seq_length, - self.head_dim, - ): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, seq_length, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) - - if self.inner_attn_ln is not None: - attn_output = self.inner_attn_ln(attn_output) - - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Kosmos2_5TextFlashAttention2(Kosmos2_5TextAttention): - """ - Kosmos2_5 text flash attention module. This module inherits from `Kosmos2_5TextAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - is_cross_attention = encoder_hidden_states is not None - bsz, q_len, _ = hidden_states.size() - - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - key_states = self._shape(self.k_proj(current_states)).transpose(1, 2) - value_states = self._shape(self.v_proj(current_states)).transpose(1, 2) - if past_key_value is not None and not is_cross_attention: - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - - query_states = self._shape(self.q_proj(hidden_states)).transpose(1, 2) - - if self.is_decoder: - past_key_value = (key_states, value_states) - - input_dtype = query_states.dtype - - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, None, q_len, dropout=self.dropout - ) - - attn_output = attn_output.view(bsz, -1, self.embed_dim) - - if self.inner_attn_ln is not None: - attn_output = self.inner_attn_ln(attn_output) - - attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class Kosmos2_5TextSdpaAttention(Kosmos2_5TextAttention): - """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from
-    `Kosmos2_5TextAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
-    the SDPA API.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    # Adapted from LlamaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            logger.warning_once(
-                "Kosmos2_5TextModel is using Kosmos2_5TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                past_key_value=past_key_value,
-                attention_mask=attention_mask,
-                layer_head_mask=layer_head_mask,
-                output_attentions=output_attentions,
-            )
-
-        is_cross_attention = encoder_hidden_states is not None
-        bsz, q_len, _ = hidden_states.size()
-        # use encoder_hidden_states if cross attention
-        current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        # checking that the `sequence_length` of the `past_key_value` is the same as the one provided by
-        # `encoder_hidden_states` to support prefix tuning
-        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(current_states))
-            value_states = self._shape(self.v_proj(current_states))
-            if past_key_value is not None and not is_cross_attention:
-                # reuse k, v, self_attention
-                key_states = torch.cat([past_key_value[0], key_states], dim=2)
-                value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        query_states = self._shape(self.q_proj(hidden_states))
-
-        if self.is_decoder:
-            past_key_value = (key_states, value_states)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- is_causal = True if causal_mask is None and q_len > 1 else False - is_causal = is_causal and self.is_causal - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - if self.inner_attn_ln is not None: - attn_output = self.inner_attn_ln(attn_output) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -KOSMOS2_5_TEXT_ATTENTION_CLASSES = { - "eager": Kosmos2_5TextAttention, - "flash_attention_2": Kosmos2_5TextFlashAttention2, - "sdpa": Kosmos2_5TextSdpaAttention, -} - - -class Kosmos2_5TextBlock(nn.Module): - def __init__(self, config: Kosmos2_5TextConfig): - super().__init__() - self.embed_dim = config.embed_dim - self.self_attn = KOSMOS2_5_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config, - embed_dim=self.embed_dim, - num_heads=config.attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - add_inner_attn_layernorm=False, - is_causal=True, - ) - self.dropout = config.dropout - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - - if config.add_cross_attention: - self.encoder_attn = KOSMOS2_5_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config, - embed_dim=self.embed_dim, - num_heads=config.attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - add_inner_attn_layernorm=False, - is_causal=True, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - - self.ffn = Kosmos2_5TextFFN(config) - self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - - hidden_states = self.self_attn_layer_norm(hidden_states) - - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - if not hasattr(self, "encoder_attn"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - residual = hidden_states - - hidden_states = 
self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - - hidden_states = self.final_layer_norm(hidden_states) - - # FFN - hidden_states = self.ffn(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class Kosmos2_5TextTransformer(nn.Module): - """ - Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2_5TextBlock`]. - - Args: - config: Kosmos2_5TextConfig - """ - - def __init__(self, config: Kosmos2_5TextConfig): - super().__init__() - self.config = config - self.dropout = config.dropout - self.layerdrop = config.layerdrop - - self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id) - - self.embed_positions = Kosmos2_5TextSinusoidalPositionalEmbedding( - num_positions=config.max_position_embeddings, - embedding_dim=config.embed_dim, - padding_idx=config.pad_token_id, - ) - self.segment_emb = nn.Embedding(2, config.embed_dim) - - self.layers = nn.ModuleList([Kosmos2_5TextBlock(config) for _ in range(config.layers)]) - self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) - - self.gradient_checkpointing = False - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward_embedding( - self, - input_ids, - inputs_embeds: torch.Tensor = None, - image_embeds: torch.Tensor = None, - img_input_mask: torch.Tensor = None, - past_key_values_length: int = 0, - position_ids: torch.Tensor = None, - ): - # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`. 
- if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if image_embeds is not None: - inputs_embeds[img_input_mask == 1] = image_embeds.to(inputs_embeds.device).view(-1, image_embeds.size(-1)) - inputs_embeds = inputs_embeds * self.embed_scale - # embed positions - positions = None - if self.embed_positions is not None: - positions = self.embed_positions( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - position_ids=position_ids, - ).to(inputs_embeds.device) - if self.segment_emb is not None: - if img_input_mask is not None: - # make every not equal 0 be 1 - img_input_mask = img_input_mask.ne(0).long() - segment_embeds = self.segment_emb(img_input_mask) - positions += segment_embeds - else: - # add zero embedding for padding tokens - bsz, seq_len, dim = positions.size() - zero_emb = self.segment_emb(torch.zeros((bsz, 1), dtype=torch.long, device=positions.device)) - positions += zero_emb - - if positions is not None: - hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - return hidden_states - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - image_embeds_position_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - # We don't need img info. 
when `past_key_values_length` > 0
-        if past_key_values_length > 0:
-            image_embeds = None
-            image_embeds_position_mask = None
-
-        hidden_states = self.forward_embedding(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            image_embeds=image_embeds,
-            img_input_mask=image_embeds_position_mask,
-            past_key_values_length=past_key_values_length,
-            position_ids=position_ids,
-        )
-
-        causal_mask = self._prepare_decoder_attention_mask(
-            attention_mask, input_shape, hidden_states, past_key_values_length
-        )
-
-        # expand encoder attention mask
-        if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1])
-
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
-        present_key_value_states = () if use_cache else None
-
-        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
-        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
-            if attn_mask is not None:
-                if attn_mask.size()[0] != (len(self.layers)):
-                    raise ValueError(
-                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
-                        f" {attn_mask.size()[0]}."
- ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - if use_cache: - present_key_value_states += (layer_outputs[3 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add final layer norm - hidden_states = self.layer_norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_self_attns, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class Kosmos2_5ImageToTextProjection(nn.Module): - """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" - - def __init__(self, config: Kosmos2_5Config): - super().__init__() - self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim) - self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim)) - self.x_attn = KOSMOS2_5_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config.text_config, - config.text_config.embed_dim, - config.text_config.attention_heads, - dropout=config.text_config.attention_dropout, - is_decoder=False, - add_inner_attn_layernorm=False, - is_causal=False, - ) - # self.dropout = nn.Dropout(config.text_config.dropout, inplace=True) - - def forward(self, features): - hidden_states = self.dense(features) - - # shape = [batch, latent_query_num, h_dim] - latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) - key_value_states = torch.cat([hidden_states, latent_query], dim=1) - - hidden_states, attn_weights, _ = self.x_attn( - hidden_states=latent_query, - encoder_hidden_states=key_value_states, - past_key_value=None, - attention_mask=None, - output_attentions=None, - ) - - return hidden_states, attn_weights - - -class Kosmos2_5PreTrainedModel(PreTrainedModel): - """ - An abstract class 
to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = Kosmos2_5Config - supports_gradient_checkpointing = True - _no_split_modules = ["Kosmos2_5VisionEncoder", "Kosmos2_5TextBlock"] - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(self, Kosmos2_5VisionModel): - factor = self.config.initializer_factor - elif isinstance(self, (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration)): - factor = self.config.vision_config.initializer_factor - - if isinstance(self, (Kosmos2_5TextModel, Kosmos2_5TextForCausalLM)): - std = self.config.init_std - elif isinstance(self, (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration)): - std = self.config.text_config.init_std - - if isinstance(module, Kosmos2_5VisionEmbeddings): - nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) - nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) - nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) - elif isinstance(module, Kosmos2_5VisionAttention): - in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor - out_proj_std = (module.embed_dim**-0.5) * factor - nn.init.normal_(module.q_proj.weight, std=in_proj_std) - nn.init.normal_(module.k_proj.weight, std=in_proj_std) - nn.init.normal_(module.v_proj.weight, std=in_proj_std) - nn.init.normal_(module.out_proj.weight, std=out_proj_std) - if module.q_proj.bias is not None: - module.q_proj.bias.data.zero_() - if module.k_proj.bias is not None: - module.k_proj.bias.data.zero_() - if module.v_proj.bias is not None: - module.v_proj.bias.data.zero_() - if module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() - elif isinstance(module, Kosmos2_5VisionMlp): - in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor - fc_std = (2 * module.config.hidden_size) ** -0.5 * factor - nn.init.normal_(module.fc1.weight, std=fc_std) - nn.init.normal_(module.fc2.weight, std=in_proj_std) - if module.fc1.bias is not None: - module.fc1.bias.data.zero_() - if module.fc2.bias is not None: - module.fc2.bias.data.zero_() - elif isinstance(module, Kosmos2_5VisionLayer): - module.pre_layrnorm.bias.data.zero_() - module.pre_layrnorm.weight.data.fill_(1.0) - module.post_layernorm.bias.data.zero_() - module.post_layernorm.weight.data.fill_(1.0) - elif isinstance(module, Kosmos2_5TextAttention): - nn.init.normal_(module.q_proj.weight, std=std) - nn.init.normal_(module.k_proj.weight, std=std) - nn.init.normal_(module.v_proj.weight, std=std) - nn.init.normal_(module.out_proj.weight, std=std) - if module.q_proj.bias is not None: - module.q_proj.bias.data.zero_() - if module.k_proj.bias is not None: - module.k_proj.bias.data.zero_() - if module.v_proj.bias is not None: - module.v_proj.bias.data.zero_() - if module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() - elif isinstance(module, Kosmos2_5TextFFN): - nn.init.normal_(module.fc1.weight, std=std) - nn.init.normal_(module.fc2.weight, std=std) - if module.fc1.bias is not None: - module.fc1.bias.data.zero_() - if module.fc2.bias is not None: - module.fc2.bias.data.zero_() - elif isinstance(module, Kosmos2_5TextForCausalLM): - nn.init.normal_(module.lm_head.weight, std=std) - if module.lm_head.bias is not None: - module.lm_head.bias.data.zero_() - elif 
isinstance(module, Kosmos2_5ImageToTextProjection): - nn.init.normal_(module.dense.weight, std=std) - if module.dense.bias is not None: - module.dense.bias.data.zero_() - elif isinstance(module, Kosmos2_5TextTransformer): - module.embed_tokens.weight.data.normal_(mean=0.0, std=std) - if module.embed_tokens.padding_idx is not None: - module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_() - - -class Kosmos2_5TextModel(Kosmos2_5PreTrainedModel): - config_class = Kosmos2_5TextConfig - - def __init__(self, config: Kosmos2_5TextConfig): - super().__init__(config) - self.model = Kosmos2_5TextTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(KOSMOS2_5_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=Kosmos2_5TextConfig, - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - image_embeds_position_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - r""" - Returns: - - """ - return self.model( - input_ids=input_ids, - attention_mask=attention_mask, - image_embeds=image_embeds, - image_embeds_position_mask=image_embeds_position_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -@add_start_docstrings( - """ - KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model. 
- """, - KOSMOS2_5_START_DOCSTRING, -) -class Kosmos2_5Model(Kosmos2_5PreTrainedModel): - config_class = Kosmos2_5Config - main_input_name = "flattened_patches" - - def __init__(self, config: Kosmos2_5Config): - super().__init__(config) - - self.text_model = Kosmos2_5TextModel(config.text_config) - self.vision_model = Kosmos2_5VisionModel(config.vision_config) - self.image_to_text_projection = Kosmos2_5ImageToTextProjection(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.model.embed_tokens - - def set_input_embeddings(self, value): - self.text_model.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(KOSMOS2_5_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Kosmos2_5ModelOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - flattened_patches: Optional[torch.Tensor] = None, - input_ids: Optional[torch.Tensor] = None, - image_embeds_position_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - image_embeds: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Kosmos2_5ModelOutput]: - r""" - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Kosmos2_5Model - - >>> model = Kosmos2_5Model.from_pretrained("microsoft/kosmos2.5") - >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos2.5") - - >>> url = "https://huggingface.co/microsoft/kosmos2.5/resolve/main/snowman.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = ( - ... " An image of a snowman" - ... " warming himself by a fire" - ... "" - ... ) - - >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True) - - >>> last_hidden_state = model( - ... pixel_values=inputs["pixel_values"], - ... input_ids=inputs["input_ids"], - ... attention_mask=inputs["attention_mask"], - ... image_embeds_position_mask=inputs["image_embeds_position_mask"], - ... ).last_hidden_state - >>> list(last_hidden_state.shape) - [1, 91, 2048] - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_model_output = None - projection_attentions = None - if image_embeds is None: - if flattened_patches is None: - raise ValueError("You have to specify either `flattened_patches` or `image_embeds`.") - - vision_model_output = self.vision_model( - flattened_patches=flattened_patches, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. 
- image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) - # normalized features - image_embeds = nn.functional.normalize(image_embeds, dim=-1) - image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) - - outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - image_embeds=image_embeds, - image_embeds_position_mask=image_embeds_position_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - outputs = outputs + ( - image_embeds, - projection_attentions, - vision_model_output, - ) - return tuple(output for output in outputs if output is not None) - - return Kosmos2_5ModelOutput( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_embeds=image_embeds, - projection_attentions=projection_attentions, - vision_model_output=vision_model_output, - ) - - -@add_start_docstrings( - """ - The text model from KOSMOS-2.5 with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - KOSMOS2_5_START_DOCSTRING, -) -class Kosmos2_5TextForCausalLM(Kosmos2_5PreTrainedModel): - config_class = Kosmos2_5TextConfig - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: Kosmos2_5TextConfig): - super().__init__(config) - - self.model = Kosmos2_5TextTransformer(config) - self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - @add_start_docstrings_to_model_forward(KOSMOS2_5_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2_5TextConfig) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - image_embeds_position_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if use_cache: - logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") - use_cache = False - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - image_embeds=image_embeds, - image_embeds_position_mask=image_embeds_position_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - lm_logits = self.lm_head(outputs[0]) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length), - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - image_embeds=None, - image_embeds_position_mask=None, - past_key_values=None, - attention_mask=None, - use_cache=None, - **model_kwargs, - ): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - position_ids = None - - # cut input_ids if past_key_values is used - if past_key_values is not None: - position_ids = create_position_ids_from_input_ids( - input_ids, - padding_idx=self.config.pad_token_id, - past_key_values_length=0, - )[:, -1:] - - input_ids = input_ids[:, -1:] - # the image info. 
is already encoded into the past keys/values - image_embeds = None - image_embeds_position_mask = None - elif image_embeds_position_mask is not None: - # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) - batch_size, seq_len = input_ids.size() - mask_len = image_embeds_position_mask.size()[-1] - image_embeds_position_mask = torch.cat( - ( - image_embeds_position_mask, - torch.zeros( - size=(batch_size, seq_len - mask_len), - dtype=torch.bool, - device=input_ids.device, - ), - ), - dim=1, - ) - - return { - "input_ids": input_ids, - "image_embeds": image_embeds, - "image_embeds_position_mask": image_embeds_position_mask, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - "position_ids": position_ids, - "use_cache": use_cache, - } - - @staticmethod - # Copied from ...models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - KOSMOS-2.5 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a - language model. - """, - KOSMOS2_5_START_DOCSTRING, -) -class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel): - config_class = Kosmos2_5Config - main_input_name = "flattened_patches" - _tied_weights_keys = ["text_model.lm_head.weight"] - - def __init__(self, config: Kosmos2_5Config): - super().__init__(config) - - self.text_model = Kosmos2_5TextForCausalLM(config.text_config) - self.vision_model = Kosmos2_5VisionModel(config.vision_config) - self.image_to_text_projection = Kosmos2_5ImageToTextProjection(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.model.embed_tokens - - def set_input_embeddings(self, value): - self.text_model.model.embed_tokens = value - - def get_output_embeddings(self) -> nn.Module: - return self.text_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.text_model.set_output_embeddings(new_embeddings) - - @add_start_docstrings_to_model_forward(KOSMOS2_5_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Kosmos2_5ForConditionalGenerationModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - flattened_patches: Optional[torch.Tensor] = None, - input_ids: Optional[torch.Tensor] = None, - image_embeds_position_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - image_embeds: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Kosmos2_5ForConditionalGenerationModelOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> import torch - >>> from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration - - >>> repo = "microsoft/kosmos-2.5" - >>> device = "cuda:0" - >>> dtype = torch.bfloat16 # torch.float16 - >>> model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype) - >>> processor = AutoProcessor.from_pretrained(repo) - - >>> url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png" - - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> prompt = "" # - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - >>> height, width = inputs.pop("height"), inputs.pop("width") - >>> inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()} - >>> inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype) - - >>> generated_ids = model.generate(**inputs,max_new_tokens=1024) - >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> generated_text - '1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n' - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_model_output = None - projection_attentions = None - if image_embeds is None: - if flattened_patches is None: - raise ValueError("You have to specify either `flattened_patches` or `image_embeds`.") - - vision_model_output = self.vision_model( - flattened_patches=flattened_patches, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - image_embeds = vision_model_output.last_hidden_state - image_embeds = nn.functional.normalize(vision_model_output[0], dim=-1) - image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) - - lm_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - image_embeds=image_embeds, - image_embeds_position_mask=image_embeds_position_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - outputs = lm_outputs + ( - image_embeds, - projection_attentions, - vision_model_output, - ) - return tuple(output for output in outputs if output is not None) - - return Kosmos2_5ForConditionalGenerationModelOutput( - loss=lm_outputs.loss, - logits=lm_outputs.logits, - past_key_values=lm_outputs.past_key_values, - hidden_states=lm_outputs.hidden_states, - attentions=lm_outputs.attentions, - image_embeds=image_embeds, - projection_attentions=projection_attentions, - vision_model_output=vision_model_output, - ) - - def generate( - self, - flattened_patches: Optional[torch.Tensor] = None, - 
image_embeds_position_mask: Optional[torch.Tensor] = None, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_attention_mask: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - **kwargs, - ): - # in order to allow `inputs` argument (as in `GenerationMixin`) - inputs = kwargs.pop("inputs", None) - if flattened_patches is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs} were passed alongside `flattened_patches` which is not allowed." - f"Make sure to either pass `inputs` or flattened_patches=..." - ) - if flattened_patches is None and inputs is not None: - flattened_patches = inputs - - if image_embeds is None: - vision_model_output = self.vision_model( - flattened_patches=flattened_patches, - attention_mask=image_attention_mask, - output_hidden_states=True, - ) - - image_embeds = nn.functional.normalize(vision_model_output[0], dim=-1) - image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) - output = self.text_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - image_embeds=image_embeds, - image_embeds_position_mask=image_embeds_position_mask, - **kwargs, - ) - - return output \ No newline at end of file +version https://git-lfs.github.com/spec/v1 +oid sha256:9bc21bb29f9c8667c089f9a6bf89ef30b13f29d050754a82c65ce88271403af9 +size 119785
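For quick reference, a minimal usage sketch of the generation path implemented by the (now LFS-tracked) modeling file, adapted from the `Kosmos2_5ForConditionalGeneration` docstring example above. The checkpoint, processor, and image URL come from that example; the `"<ocr>"` task prompt (and the alternative `"<md>"` markdown prompt) is an assumption, since the prompt string in the docstring example appears empty.

```python
# Sketch only: mirrors the docstring example of Kosmos2_5ForConditionalGeneration.
# The "<ocr>" / "<md>" task prompts are assumptions, not taken verbatim from this file.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

repo = "microsoft/kosmos-2.5"
device = "cuda:0"
dtype = torch.bfloat16

model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
processor = AutoProcessor.from_pretrained(repo)

url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)

# The processor turns the image into `flattened_patches` (the model's main input) and
# builds `image_embeds_position_mask` marking where image features are spliced into the text.
inputs = processor(text="<ocr>", images=image, return_tensors="pt")
height, width = inputs.pop("height"), inputs.pop("width")
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

generated_ids = model.generate(**inputs, max_new_tokens=1024)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```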