from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.models.clip.configuration_clip import CLIPConfig
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import AutoModel, AutoModelForCausalLM
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaMultiModalProjector,
    LlavaPreTrainedModel,
    LLAVA_START_DOCSTRING,
    LLAVA_INPUTS_DOCSTRING,
    LlavaForConditionalGeneration,
)
from transformers.models.blip_2.configuration_blip_2 import (
    Blip2Config,
    Blip2QFormerConfig,
)
import os
from transformers.models.blip_2.modeling_blip_2 import (
    Blip2Config,
    Blip2QFormerModel,
    Blip2PreTrainedModel,
    BLIP_2_INPUTS_DOCSTRING,
)
from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10

# from .configuration_sealmm import SeaLMMConfig

logger = logging.get_logger(__name__)

# _CONFIG_FOR_DOC = "LlavaConfig"
_CONFIG_FOR_DOC = "SeaLMMConfig"


class SeaLMMConfig(LlavaConfig):
    def __init__(self, *args, **kwargs):
        self.projector_num_layers = kwargs.get("projector_num_layers", 1)
        super().__init__(*args, **kwargs)


"""
Llava
vision_config.num_hidden_layers = vision_config.num_hidden_layers + config.vision_feature_layer + 1
# "num_hidden_layers": 24,
"""
IMAGE_TOKEN = "<|image|>"
DEBUG = bool(int(os.environ.get("DEBUG", "0")))


def by_sample_merge_input_ids_with_image_features(
    self, image_features, inputs_embeds, input_ids, attention_mask=None, position_ids=None
):
    """
    Per-sample variant of the merge logic (unfinished; kept for reference).

    input_ids: [tlen]
    input_embeds: [tlen, dt]
    img_embeds: [ilen, ifeat, di]
    e.g:
    input_ids: [
        a b c d e f X g h i j k X l m
    ]
    img_embeds: [3, ifeat, id]  # img_embeds has padding
    """
    num_images, num_image_patches, embed_dim = image_features.shape
    sequence_length = input_ids.size(0)
    # input_ids is 1-D here, so check the last token directly for right padding.
    left_padding = not torch.sum(input_ids[-1] == torch.tensor(self.pad_token_id))
    assert not left_padding, 'should only use right padding'
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.models.clip.modeling_clip import (
    contrastive_loss,
    clip_loss,
    CLIPVisionModelOutput,
    CLIPTextModelOutput,
    CLIPOutput,
    CLIPTextEmbeddings,
    CLIPVisionEmbeddings,
    CLIPAttention,
    CLIPMLP,
    CLIPEncoderLayer,
    CLIPPreTrainedModel,
    CLIPTextTransformer,
    CLIPTextModel,
    CLIPVisionTransformer,
    CLIPVisionModel,
    CLIPModel,
    CLIPEncoder,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
    CLIP_START_DOCSTRING,
    CLIP_TEXT_INPUTS_DOCSTRING,
    CLIP_VISION_INPUTS_DOCSTRING,
    CLIP_INPUTS_DOCSTRING,
)
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    import torch.nn.functional as F

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
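
# --- Illustration (not part of the model) -------------------------------------
# A minimal sketch of what `_get_unpad_data` returns for a right-padded 2-sample
# batch. Defined but never called, so importing this module stays side-effect free.
def _example_get_unpad_data():
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
    indices, cu_seqlens, max_len = _get_unpad_data(mask)
    # indices    -> positions of the 5 real tokens in the flattened [2 * 4] mask
    # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
    # max_len    -> 3
    return indices, cu_seqlens, max_len
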
class CLIPFlashAttention2(CLIPAttention):
    """
    CLIP flash attention module. This module inherits from `CLIPAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, config, is_causal=False):
        super().__init__(config)
        self.is_causal = is_causal

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions:
            raise ValueError("CLIPFlashAttention2 does not support output_attentions")
        if self.is_causal and causal_attention_mask is None:
            raise ValueError("CLIPFlashAttention2 has causal=True but no causal_attention_mask provided")

        bsz, tgt_len, embed_dim = hidden_states.size()

        # [batch_size, tgt_len, embed_dim]
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [batch_size, tgt_len, embed_dim] -> [batch_size, tgt_len, num_heads, head_dim]
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
        key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
        value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()

        attn_output = self._flash_attention_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attention_mask,
            query_length=tgt_len,
            # Only apply attention dropout during training.
            dropout=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
        )

        # [batch_size, tgt_len, num_heads, head_dim] -> [batch_size, tgt_len, embed_dim]
        attn_output = attn_output.view(bsz, tgt_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None
    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ) -> torch.Tensor:
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The :query_length slice assumes right padding.
            attention_mask = attention_mask[:, :query_length]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class SeaLMMCLIPEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config: CLIPConfig):
        super(CLIPEncoderLayer, self).__init__()
        self.embed_dim = config.hidden_size
        # self.self_attn = LlavaCLIPFlashAttention(config)
        if is_flash_attn_greater_or_equal_2_10():
            self.self_attn = CLIPFlashAttention2(config)
        else:
            self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)


class SeaLMMCLIPEncoder(CLIPEncoder):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super(CLIPEncoder, self).__init__()
        self.config = config
        self.layers = nn.ModuleList([SeaLMMCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # This encoder never materializes per-layer hidden states or attention maps.
        output_hidden_states = False
        output_attentions = False
        # return_dict = False

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         encoder_layer.__call__,
            #         hidden_states,
            #         attention_mask,
            #         causal_attention_mask,
            #         output_attentions,
            #     )
            # else:
            # ! enforce no checkpointing here
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
class SeaLMMVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: keep the original (misspelled) CLIP attribute name `pre_layrnorm`
        # so that pretrained CLIP checkpoints load without key remapping.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # self.encoder = CLIPEncoder(config)
        self.encoder = SeaLMMCLIPEncoder(config)
        # self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        assert output_attentions is None
        assert output_hidden_states is None
        # assert return_dict is None
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        if not return_dict:
            raise ValueError("return_dict=False is not supported")

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            # pooler_output=pooled_output,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class SeaLMMCLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["SeaLMMCLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = SeaLMMVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class SeaLMMMultiModalProjector(SeaLMMCLIPEncoder):
    def __init__(self, config: SeaLMMConfig):
        super(CLIPEncoder, self).__init__()
        self.config = config
        self.projector_num_layers = getattr(config, "projector_num_layers", 2)
        self.vision_config = config.vision_config
        self.num_vision_feature_layer = int(0 - config.vision_feature_layer) - 1
        assert self.num_vision_feature_layer > 0

        self.layers = nn.ModuleList([
            # LlavaCLIPFasterEncoderLayer(self.vision_config)
            SeaLMMCLIPEncoderLayer(self.vision_config)
            for _ in range(self.projector_num_layers)
        ])

        projector_layernorm_eps = getattr(config, "projector_layernorm_eps", 1e-05)
        self.projector_layernorm = nn.LayerNorm(
            # len(config.vision_feature_layers) * config.vision_config.hidden_size, eps=projector_layernorm_eps
            config.vision_config.hidden_size, eps=projector_layernorm_eps
        )
        self.linear_1 = nn.Linear(
            # len(config.vision_feature_layers) * config.vision_config.hidden_size,
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )
        # self.act = ACT2FN[config.projector_hidden_act]
        # self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None):
        """
        hidden_states must not be stripped of the CLS token: the CLS position is
        dropped here, after the extra encoder layers have run.
        """
        output_attentions = False
        for idx, encoder_layer in enumerate(self.layers):
            # if output_hidden_states:
            #     encoder_states = encoder_states + (hidden_states,)
            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         encoder_layer.__call__,
            #         hidden_states,
            #         attention_mask,
            #         causal_attention_mask,
            #         output_attentions,
            #     )
            # else:
            # ! turn off checkpointing
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

        # Drop the CLS token, then project vision features into the LM embedding space.
        hidden_states = hidden_states[:, 1:]
        hidden_states = self.projector_layernorm(hidden_states)
        hidden_states = self.linear_1(hidden_states)
        # hidden_states = self.act(hidden_states)
        # hidden_states = self.linear_2(hidden_states)
        return hidden_states
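
# --- Illustration (not part of the model) -------------------------------------
# Shape flow through the projector, assuming a CLIP ViT-L/14-336 tower
# (hidden size 1024, 576 patches + 1 CLS) and a 4096-d language model:
#   input hidden_states:                  [num_images, 577, 1024]
#   after `projector_num_layers` layers:  [num_images, 577, 1024]
#   drop CLS (hidden_states[:, 1:]):      [num_images, 576, 1024]
#   layernorm + linear_1:                 [num_images, 576, 4096]
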
class SeaLMMForCausalLM(LlavaPreTrainedModel):
    def __init__(self, config: SeaLMMConfig, vision_tower=None, language_model=None):
        super().__init__(config)
        # self.vision_tower = AutoModel.from_config(config.vision_config)
        # self.vision_tower = vision_tower or LlavaCLIPVisionModel(config=config.vision_config)
        self.vision_tower = vision_tower or SeaLMMCLIPVisionModel(config=config.vision_config)
        self.multi_modal_projector = SeaLMMMultiModalProjector(config)
        # self.vocab_size = config.text_config.vocab_size
        self.language_model = language_model or AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
        self.freeze_vision_tower = True

    def unfreeze_vision_tower(self):
        logger.info(f'UNFREEZE {self.freeze_vision_tower=}')
        self.freeze_vision_tower = False

    # NOTE: renamed from `freeze_vision_tower` because the boolean attribute of the
    # same name set in `__init__` would shadow the method and make it uncallable.
    def refreeze_vision_tower(self):
        logger.info(f'FREEZE {self.freeze_vision_tower=}')
        self.freeze_vision_tower = True

    @classmethod
    def create_model_config_from_components(
        cls,
        lm_config=None,
        vision_config=None,
        tokenizer=None,
        vision_feature_layer=None,
        projector_num_layers=1,
        **kwargs,
    ) -> SeaLMMConfig:
        # self.projector_num_layers = kwargs.get("projector_num_layers", 1)
        config = SeaLMMConfig(vision_config, lm_config, projector_num_layers=projector_num_layers, **kwargs)
        config.vision_feature_layer = config.vision_feature_layer if vision_feature_layer is None else vision_feature_layer

        # Truncate the vision tower so it ends at the requested feature layer.
        if config.vision_feature_layer < 0:
            config.vision_config.num_hidden_layers = config.vision_config.num_hidden_layers + config.vision_feature_layer + 1
        else:
            config.vision_config.num_hidden_layers = config.vision_feature_layer + 1

        if IMAGE_TOKEN not in tokenizer.get_vocab():
            tokenizer.add_special_tokens({"cls_token": IMAGE_TOKEN})
        config.image_token_index = tokenizer.cls_token_id
        config.vocab_size = config.text_config.vocab_size
        config.architectures = ["SeaLMMForCausalLM"]
        return config
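
    # --- Illustration (not part of the model) ---------------------------------
    # A minimal sketch of building a config with the classmethod above, assuming
    # a CLIP ViT-L/14-336 vision tower and a Mistral-family language model; the
    # checkpoint names are placeholders, not a statement about what was actually used.
    #
    #   from transformers import AutoConfig, AutoTokenizer
    #   vision_config = AutoConfig.from_pretrained("openai/clip-vit-large-patch14-336").vision_config
    #   lm_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
    #   tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    #   config = SeaLMMForCausalLM.create_model_config_from_components(
    #       lm_config=lm_config,
    #       vision_config=vision_config,
    #       tokenizer=tokenizer,
    #       vision_feature_layer=-2,
    #       projector_num_layers=1,
    #   )
    #   model = SeaLMMForCausalLM(config)
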
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
    # @torch.no_grad
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids, labels=None
    ):
        """
        input_ids: [b, tlen]
        input_embeds: [b, tlen, dt]
        image_features: [b, ilen, ifeat, di]
        labels: None or [b, tlen] --> must extend labels to input_ids

        # in input_ids, there may be image_token_index, number of image_token_index <= ilen
        input_ids: [
            a b c d e f X g h i j k X l m
            o p q r X s t u v _ _ _ _ _ _
        ]
        input_ids should be: [
            a b c d e f X X X X X g h i j k X X X X X l m
            o p q r X X X X X s t u v _ _ _ _ _ _ _ _ _ _
        ]
        labels should be: [
            a b c d e f _ _ _ _ _ g h i j k _ _ _ _ _ l m
            o p q r _ _ _ _ _ s t u v _ _ _ _ _ _ _ _ _ _
        ]
        # mask replace image onto it
        # Use torch.vmap for simplicity
        def sample_merge():
            input_ids: [tlen]
            input_embeds: [tlen, dt]
            img_embeds: [ilen, ifeat, di]
            e.g:
            input_ids: [
                a b c d e f X g h i j k X l m
            ]
            img_embeds: [3, ifeat, id]  # img_embeds has padding
        """
        # NOTE: reconstructed scope - only the index bookkeeping is kept under no_grad,
        # so gradients can still flow from the LM loss into the merged embeddings below.
        with torch.no_grad():
            num_images, num_image_patches, embed_dim = image_features.shape
            batch_size, sequence_length = input_ids.shape
            # left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
            left_padding = torch.any(attention_mask[:, 0] == 0)
            # assert not left_padding or batch_size == 1
            # 1. Create a mask to know where special image tokens are
            special_image_token_mask = input_ids == self.config.image_token_index
            num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
            # Reserve for padding of num_images
            total_num_special_image_tokens = torch.sum(special_image_token_mask)
            assert total_num_special_image_tokens == num_images, (
                f'{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}'
            )
            # Compute the maximum embed dimension
            max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
            batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

            # 2. Compute the positions where text should be written
            # Calculate new positions for text tokens in merged image-text sequence.
            # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
            # `torch.cumsum` computes how each image token shifts subsequent text token positions.
            # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
            new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
            nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
            if left_padding:
                new_token_positions += nb_image_pad[:, None]  # offset for left padding
            text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        final_labels = None
        if labels is not None:
            final_labels = torch.full_like(final_attention_mask, self.config.ignore_index).to(torch.long)
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
        if left_padding:
            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
        else:
            val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim) < new_token_positions[:, -1:].to(target_device)
            image_to_overwrite &= val

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        if not left_padding:
            # Make sure the merged mask is still contiguous right padding.
            seq_lens = final_attention_mask.sum(-1)
            for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
                # seq_len = mask.sum(-1)
                assert torch.all(mask[:seq_len] == 1), f'final 1 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
                assert torch.all(mask[seq_len:] == 0), f'final 0 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'

        # if DEBUG:
        #     print(f'final_attention_mask=\n{final_attention_mask.tolist()}')
        #     print(f'text_to_overwrite=\n{text_to_overwrite.int().tolist()}')
        #     print(f'image_to_overwrite=\n{image_to_overwrite.int().tolist()}')
        #     print(f'position_ids=\n{position_ids.tolist()}')
        #     print(f'labels=\n{labels.tolist()}')
        #     print(f'final_labels=\n{final_labels.tolist()}')
        return final_embedding, final_attention_mask, position_ids, final_labels
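
    # --- Illustration (not part of the model) ---------------------------------
    # Worked example of the length arithmetic above, assuming 576 patches per
    # image (a 336x336 ViT-L/14 tower) and a sample with 2 `<|image|>` tokens in
    # a length-32 prompt: each image token expands into 576 positions, so
    #   max_embed_dim = 32 + 2 * (576 - 1) = 1182
    # and the merged embedding for that sample has shape [1182, embed_dim].
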
    def extract_image_features(self, pixel_values, vision_feature_select_strategy=None):
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
            image_outputs = self.vision_tower(pixel_values)
            hidden_states = image_outputs.last_hidden_state
        # The projector call stays outside `no_grad` so it can still receive gradients
        # when the vision tower is frozen.
        image_features = self.multi_modal_projector(hidden_states)
        return image_features
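
    # --- Illustration (not part of the model) ---------------------------------
    # Expected shape flow through `extract_image_features`, assuming a 336x336
    # CLIP ViT-L/14 tower (24x24 = 576 patches + 1 CLS):
    #   pixel_values:   [num_images, 3, 336, 336]
    #   tower output:   [num_images, 577, vision_hidden_size]
    #   image_features: [num_images, 576, text_hidden_size]
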
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            # 1. Extract the input embeddings; image-token positions are temporarily
            # mapped to id 0 so the embedding lookup does not index out of range.
            for_inputs_embeds_ids = input_ids.clone()
            for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
            # inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
                num_images = pixel_values.size(0)
                batch_size, sequence_length = input_ids.shape
                special_image_token_mask = input_ids == self.config.image_token_index
                # Reserve for padding of num_images
                total_num_special_image_tokens = torch.sum(special_image_token_mask)
                assert num_images == total_num_special_image_tokens, (
                    f'{num_images} < {total_num_special_image_tokens} | {special_image_token_mask}'
                )
                # pixel_values = pixel_values[:total_num_special_image_tokens]
                # image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                # with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
                #     image_outputs = self.vision_tower(pixel_values)
                #     # this is not memory efficient at all: output_hidden_states=True saves all the hidden states.
                #     # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
                #     selected_image_feature = image_outputs.last_hidden_state
                #     if vision_feature_select_strategy == "default":
                #         selected_image_feature = selected_image_feature[:, 1:]
                #     elif vision_feature_select_strategy == "full":
                #         selected_image_feature = selected_image_feature
                #     else:
                #         raise ValueError(
                #             f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                #         )
                #     image_features = self.multi_modal_projector(selected_image_feature)
                # print(f"{pixel_values.size()=}")

                # ! extract_image_features handles all image feature extraction
                image_features = self.extract_image_features(pixel_values)
                # if DEBUG:
                #     image_features = image_features[:, :3]

                inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, position_ids,
                    labels=labels
                )
                # if labels is None:
                #     # ! this is wrong!
                #     labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
                # print(inputs_embeds.size())
            elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
                # there are no images in the batch
                pass
            else:
                # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                # generation with cache
                # ! (phi) why do we need to do this?
                # if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                #     # ! this could possibly be a bug because with mistral the first layer key looks like this
                #     # ! MUST UNDERSTAND and fix error
                #     # Retrieve the first layer to inspect the logits and mask out the hidden states
                #     # that are set to 0
                #     first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
                #     batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
                #     # Get the target length
                #     target_seqlen = first_layer_past_key_value.shape[-1] + 1
                #     extended_attention_mask = torch.ones(
                #         (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                #         dtype=attention_mask.dtype,
                #         device=attention_mask.device,
                #     )
                #     # print(f'{extended_attention_mask.shape} | {batch_index=} | {non_attended_tokens=}')
                #     # Zero-out the places where we don't need to attend
                #     extended_attention_mask[batch_index, non_attended_tokens] = 0
                #     attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                #     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

                # ! fix: https://github.com/huggingface/transformers/blob/c90268de7560c3fef21a927e0bfcf2b611a8711e/src/transformers/models/llava/modeling_llava.py
                # https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # in the case one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1:]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)
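

# --- Illustration (not part of the model) -------------------------------------
# A minimal inference sketch, assuming `checkpoint` points to a trained SeaLMM
# checkpoint that ships a LLaVA-style processor inserting the `<|image|>` token;
# the path and image file are placeholders. Defined but never called.
def _example_generate(checkpoint: str = "path/to/sealmm-checkpoint"):
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(checkpoint)
    model = SeaLMMForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map="auto")

    image = Image.open("example.jpg")
    prompt = f"{IMAGE_TOKEN}\nUSER: What is in the picture?\nASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    generate_ids = model.generate(**inputs, max_new_tokens=64)
    return processor.batch_decode(generate_ids, skip_special_tokens=True)[0]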