from dataclasses import dataclass import logging import gc from einops import rearrange, repeat from typing import Optional, List, Tuple, Callable, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, ) from transformers.modeling_outputs import ModelOutput from transformers.configuration_utils import PretrainedConfig from transformers import ( AutoConfig, AutoModelForCausalLM, PreTrainedModel ) from transformers.utils import logging from .siglip_vit import VisionTransformer from .configuration_deepseek import DeepseekV2Config from .modeling_deepseek import DeepseekV2ForCausalLM logger = logging.get_logger(__name__) class MlpProjector(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg if cfg.projector_type == "identity": modules = nn.Identity() elif cfg.projector_type == "linear": modules = nn.Linear(cfg.input_dim, cfg.n_embed) elif cfg.projector_type == "mlp_gelu": mlp_depth = cfg.depth modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] for _ in range(1, mlp_depth): modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) modules = nn.Sequential(*modules) elif cfg.projector_type == "downsample_mlp_gelu": mlp_depth = cfg.depth mlp_ratio = cfg.mlp_ratio modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)] for _ in range(1, mlp_depth - 1): modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) modules = nn.Sequential(*modules) else: raise ValueError(f"Unknown projector type: {cfg.projector_type}") if cfg.token_pooling: self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim) self.layers = modules def forward(self, x): if self.cfg.token_pooling: batch_size, wxh, channels = x.shape w = h = int(wxh ** 0.5) x = x.view(batch_size, w, h, channels) x = x.permute(0, 3, 1, 2) # import ipdb; ipdb.set_trace() patches = x.unfold(2, 2, 2).unfold(3, 2, 2) batch_size, channels, h_patches, w_patches, _, _ = patches.size() # 在通道维度上拼接 patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1) # 通过线性层 patches = patches.permute(0, 2, 1, 3).contiguous() patches = patches.view(batch_size, h_patches * w_patches, channels * 4) x = self.token_pooling_layer(patches) elif self.cfg.projector_type == 'downsample_mlp_gelu': bs, hw, input_dim = x.shape h = w = int((hw) ** 0.5) """compute padding""" if h % self.cfg.downsample_ratio: pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio else: pad = 0 x = x.reshape(bs, h, w, input_dim) if pad > 0: x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) """4 to 1 concat""" x = x.permute(0, 3, 1, 2) # B, C, H, W x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4 x = x.permute(0, 2, 1) return self.layers(x) class VisionEncoderConfig(PretrainedConfig): model_type: str = "vision" model_name: str = "siglip_large_patch16_384" image_size: int = 384 patch_size: int = 16 width: int = 1024 layers: int = 24 heads: int = 16 mlp_ratio: int = 4 global_pool: str = "map" ignore_head: bool = True class_token: bool = False num_classes: int = 0 use_checkpoint: bool = False weight_init: str = "skip" deterministic: bool = False num_recomputing_layers: int = 0 def __init__( self, model_name: str = "siglip_large_patch16_384", image_size: int = 384, patch_size: int = 16, width: int = 1024, layers: int = 24, heads: int = 16, mlp_ratio: int = 4, global_pool: str = "map", ignore_head: bool = True, class_token: bool = False, num_classes: int = 0, use_checkpoint: bool = False, **kwargs ): self.model_name = model_name self.image_size = image_size self.patch_size = patch_size self.width = width self.layers = layers self.heads = heads self.mlp_ratio = mlp_ratio self.global_pool = global_pool self.ignore_head = ignore_head self.class_token = class_token self.num_classes = num_classes self.use_checkpoint = use_checkpoint super().__init__(**kwargs) class MlpProjectorConfig(PretrainedConfig): model_type = "mlp_projector" projector_type: str = "downsample_mlp_gelu" input_dim: int = 1152 n_embed: int = 2048 depth: int = 2 mlp_ratio: int = 1 downsample_ratio: int = 2 token_pooling: bool = False def __init__( self, projector_type: str = "downsample_mlp_gelu", input_dim: int = 1152, n_embed: int = 2048, depth: int = 2, mlp_ratio: int = 1, downsample_ratio: int = 2, **kwargs ): self.projector_type = projector_type self.input_dim = input_dim self.n_embed = n_embed self.depth = depth self.mlp_ratio = mlp_ratio self.downsample_ratio = downsample_ratio super().__init__(**kwargs) @dataclass class DeepSeekVLV2CausalLMOutputWithPast(ModelOutput): """ Base class for DeepSeek-VL2 causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None rope_deltas: Optional[torch.LongTensor] = None class DeepseekVLV2Config(PretrainedConfig): model_type = "deepseek_vl_v2" vision_config: VisionEncoderConfig projector_config: MlpProjectorConfig language_config: DeepseekV2Config tile_tag: str = "2D" global_view_pos: str = "head" candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),) def __init__( self, tile_tag: str = "tile_tag", global_view_pos: str = "head", candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),), **kwargs ): super().__init__(**kwargs) vision_config = kwargs.get("vision_config", {}) self.vision_config = VisionEncoderConfig(**vision_config) projector_config = kwargs.get("projector_config", {}) self.projector_config = MlpProjectorConfig(**projector_config) language_config = kwargs.get("language_config", {}) if isinstance(language_config, DeepseekV2Config): self.language_config = language_config else: self.language_config = DeepseekV2Config(**language_config) self.tile_tag = tile_tag self.global_view_pos = global_view_pos self.candidate_resolutions = candidate_resolutions class DeepseekVLV2PreTrainedModel(PreTrainedModel): config_class = DeepseekVLV2Config base_model_prefix = "deepseek_vl_v2" _no_split_modules = [] _skip_keys_device_placement = "past_key_values" class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel): def __init__(self, config: DeepseekVLV2Config): super().__init__(config) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # ----------- vision encoder ------------ vision_config = config.vision_config self.vision = VisionTransformer( img_size=vision_config.image_size, patch_size=vision_config.patch_size, embed_dim=vision_config.width, depth=vision_config.layers, num_heads=vision_config.heads, mlp_ratio=vision_config.mlp_ratio, class_token=vision_config.class_token, global_pool=vision_config.global_pool, ignore_head=vision_config.ignore_head, weight_init=vision_config.weight_init, num_classes=0, deterministic=vision_config.deterministic, num_recomputing_layers=vision_config.num_recomputing_layers ) # ----------- vl projector ------------ projector_config = config.projector_config self.projector = MlpProjector(projector_config) # image token format 形式 # FIXME 目前tile tag & global_view_pos的默认取值都是之前的实验策略;后续应当去掉默认取值,改为没有取值就raise error self.tile_tag = config.tile_tag self.global_view_pos = config.global_view_pos # 用于format image token sequence的特殊token embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32)) if self.tile_tag == "2D": # <|view_separator|>, <|\n|> self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std) # fix the typo: view_seperater self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std) elif self.tile_tag == "1D": # <|tile_x|>, <|tile_global|> candidate_resolutions = config.candidate_resolutions if len(candidate_resolutions) == 0: raise ValueError( f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}") tile_variants_num = len(candidate_resolutions) self.tile_indicators = nn.Parameter( torch.randn(size=(tile_variants_num + 1, config.aligner.params.n_embed)) * embed_std ) else: raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}") # ----------- language model ------------ language_config = config.language_config self.language = DeepseekV2ForCausalLM(language_config) def prepare_inputs_embeds( self, input_ids: torch.LongTensor, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.LongTensor] = None, images_spatial_crop: Optional[torch.LongTensor] = None, **ignore_kwargs ): """ Args: input_ids (torch.LongTensor): [b, T] images (torch.FloatTensor): [b, max_n_images, 3, height, width] images_seq_mask (torch.BoolTensor): [b, T] images_spatial_crop (torch.LongTensor): [b, max_n_images, 2] Returns: input_embeds (torch.Tensor): [b, T, D] """ if images is None or images_spatial_crop.sum() == 0: return self.language.get_input_embeddings()(input_ids) bs, max_n_images, _ = images_spatial_crop.shape batch_num_tiles = [0 for _ in range(bs)] total_tiles = [] for idx in range(bs): for jdx in range(max_n_images): num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx] if num_width_tiles == 0 or num_height_tiles == 0: break batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles) total_tiles.append(images[idx, :batch_num_tiles[idx]]) # [batch_all_tiles, 3, height, width] total_tiles = torch.cat(total_tiles, dim=0) assert total_tiles.shape[0] == sum(batch_num_tiles) if total_tiles.shape[0] == 0: return self.language.get_input_embeddings()(input_ids) # [batch_all_tiles, vit_seq_len, c] images_feature = self.vision(total_tiles) # [batch_all_tiles, hw, D] images_embeds = self.projector(images_feature) _, hw, n_dim = images_embeds.shape h = w = int(hw ** 0.5) # put image tokens into the input_embeds, [b, T, D] input_embeds = self.language.get_input_embeddings()(input_ids) # 根据self.tile_tag & self.global_view_pos填充image token sequence tile_index = 0 for idx in range(images_spatial_crop.shape[0]): images_in_this_batch = [] for jdx in range(images_spatial_crop.shape[1]): # extra global & local features num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx] if num_width_tiles == 0 or num_height_tiles == 0: break num_tiles_in_image = num_width_tiles * num_height_tiles # [hw, D] global_features = images_embeds[tile_index] # [num_height_tiles * num_width_tiles, hw, D] local_features = images_embeds[tile_index + 1: tile_index + 1 + num_tiles_in_image] tile_index += num_tiles_in_image + 1 # format global and local features if self.tile_tag == "2D": # ----------------- global view add newline ----------------- # [hw, D] -> [h, w, D] global_features = global_features.view(h, w, n_dim) # [D] -> [h, 1, D] new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] global_features = torch.cat([global_features, new_lines_in_global], dim=1) # [h, w + 1, D] -> [h * (w + 1), D] global_features = global_features.view(-1, n_dim) # ----------------- local view add newline ----------------- # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D] local_features = rearrange( local_features, "(th tw) (h w) d -> (th h) (tw w) d", th=num_height_tiles, tw=num_width_tiles, h=h, w=w ) # [D] -> [num_height_tiles * h, 1, D] new_lines_in_local = repeat( self.image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h ) # [num_height_tiles * h, num_width_tiles * w + 1, D] local_features = torch.cat([local_features, new_lines_in_local], dim=1) # [num_height_tiles * h, num_width_tiles * w + 1, D] # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] local_features = local_features.view(-1, n_dim) # ----------------- merge global and local tiles ----------------- if self.global_view_pos == "head": global_local_features = torch.cat( [global_features, self.view_seperator[None, :], local_features], dim=0) else: global_local_features = torch.cat( [local_features, self.view_seperator[None, :], global_features], dim=0) else: # abandoned,实际上不会走这个逻辑 global_features = torch.cat( [self.tile_indicators[0:1], global_features], dim=0 ) local_features = torch.cat( [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1 ) local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d') if self.global_view_pos == "head": global_local_features = torch.cat([global_features, local_features], dim=0) else: global_local_features = torch.cat([local_features, global_features], dim=0) images_in_this_batch.append(global_local_features) if len(images_in_this_batch) > 0: images_in_this_batch = torch.cat(images_in_this_batch, dim=0) input_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch) return input_embeds @torch.no_grad() def incremental_prefilling( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.LongTensor] = None, images_spatial_crop: Optional[torch.LongTensor] = None, chunk_size: int = 1024 ): if inputs_embeds is None: inputs_embeds = self.prepare_inputs_embeds( input_ids=input_ids, images=images, images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, ) del images del images_seq_mask del images_spatial_crop if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) self._clear_cuda_cache() bzs, seq_len, _ = inputs_embeds.shape past_key_values = None # remain the last token for the next forward prefilling_len = seq_len - 1 for i in range(0, prefilling_len, chunk_size): chunk_start = i chunk_end = min(i + chunk_size, prefilling_len) chunk_inputs_embeds = inputs_embeds[:, chunk_start: chunk_end] chunk_attention_mask = attention_mask[:, 0: chunk_end] # print(f"start = {chunk_start}, end = {chunk_end}, prefilling_len = {prefilling_len}, seq_len = {seq_len}") # compute position_ids if past_key_values is not None: position_ids = torch.arange( chunk_start, chunk_end, dtype=torch.long, device=inputs_embeds.device ).unsqueeze(0) past_key_values = self._move_past_key_values_to_gpu(past_key_values, inputs_embeds.device) else: position_ids = None # chunk-forward with torch.no_grad(): outputs = self.forward( inputs_embeds=chunk_inputs_embeds, attention_mask=chunk_attention_mask, past_key_values=past_key_values, position_ids=position_ids, use_cache=True, ) # update past_key_values past_key_values = outputs.past_key_values past_key_values = self._move_past_key_values_to_cpu(past_key_values) del outputs, position_ids self._clear_cuda_cache() prefilling_key_values = [] for layer_past in past_key_values: prefilling_key_values.append( ( layer_past[0][:, :, 0: prefilling_len, ...].to(inputs_embeds.device), layer_past[1][:, :, 0: prefilling_len, ...].to(inputs_embeds.device), ) ) return inputs_embeds, prefilling_key_values def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.LongTensor] = None, images_spatial_crop: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: inputs_embeds = self.prepare_inputs_embeds( input_ids=input_ids, images=images, images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, ) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # print(inputs_embeds.shape) outputs = self.language.forward( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position ) return outputs def _clear_cuda_cache(self): """clear CUDA memory cache""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() def _move_past_key_values_to_cpu(self, past_key_values): # print(f"past_key_values -> cpu") if past_key_values is None: return None return tuple(tuple(t.cpu() for t in layer) for layer in past_key_values) def _move_past_key_values_to_gpu(self, past_key_values, device="cuda:0"): # print(f"past_key_values -> gpu") if past_key_values is None: return None return tuple(tuple(t.to(device) for t in layer) for layer in past_key_values) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, images: Optional[torch.FloatTensor] = None, images_seq_mask: Optional[torch.LongTensor] = None, images_spatial_crop: Optional[torch.LongTensor] = None, attention_mask=None, cache_position=None, pixel_values=None, image_sizes=None, num_logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = self.language.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, num_logits_to_keep=num_logits_to_keep, **kwargs, ) # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model cache_position = model_inputs["cache_position"] if cache_position[0] == 0: model_inputs["images"] = images model_inputs["images_seq_mask"] = images_seq_mask model_inputs["images_spatial_crop"] = images_spatial_crop return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past ), ) return reordered_past AutoConfig.register("vision", VisionEncoderConfig) AutoConfig.register("mlp_projector", MlpProjectorConfig) AutoConfig.register("deepseek_vl_v2", DeepseekVLV2Config) AutoModelForCausalLM.register(DeepseekVLV2Config, DeepseekVLV2ForCausalLM)