try:
    from attrdict import AttrDict  # noqa: F401  (not used in this module; kept for compatibility)
except ImportError:  # the original `attrdict` package cannot be imported on Python >= 3.10
    AttrDict = None
from dataclasses import dataclass
import logging
import gc

from einops import rearrange, repeat
from typing import Optional, List, Tuple, Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import ModelOutput
from transformers.configuration_utils import PretrainedConfig
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel
)
from transformers.utils import logging

from .siglip_vit import VisionTransformer
from .configuration_deepseek import DeepseekV2Config
from .modeling_deepseek import DeepseekV2ForCausalLM


logger = logging.get_logger(__name__)


class MlpProjector(nn.Module):
    """Project vision-encoder features into the language model's embedding space.

    The projection is selected by ``cfg.projector_type``:

    * ``"identity"``: no-op.
    * ``"linear"``: a single ``Linear(input_dim, n_embed)``.
    * ``"mlp_gelu"``: ``depth`` linear layers separated by GELU.
    * ``"downsample_mlp_gelu"``: in ``forward`` the token grid is first merged
      in ``downsample_ratio x downsample_ratio`` windows (channel concat), then
      fed to a GELU MLP whose hidden width is ``n_embed * mlp_ratio``.

    When ``cfg.token_pooling`` is truthy, ``forward`` instead merges 2x2 token
    windows with ``token_pooling_layer`` before applying ``self.layers``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.depth
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.depth
            mlp_ratio = cfg.mlp_ratio
            # The input width is multiplied by ratio**2 because forward() concatenates
            # every ratio x ratio window of tokens along the channel dimension.
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.token_pooling:
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        self.layers = modules

    def forward(self, x):
        """Project a batch of vision token features.

        Args:
            x: ``[batch, num_tokens, input_dim]``. ``num_tokens`` is treated as a
                square grid (its side is ``int(num_tokens ** 0.5)``).

        Returns:
            The output of ``self.layers``. The token count is reduced when token
            pooling or the ``downsample_mlp_gelu`` projector is used.
        """
        if self.cfg.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            # Cut the grid into non-overlapping 2x2 windows: [B, C, H/2, W/2, 2, 2].
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # Concatenate the 4 tokens of every window along the channel dimension.
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # Run the merged windows through the pooling linear layer.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        elif self.cfg.projector_type == 'downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # Pad H and W up to the next multiple of downsample_ratio.
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Pad order is (C_left, C_right, W_left, W_right, H_top, H_bottom).
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Merge each ratio x ratio window into one token (channel concat).
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio,
                         padding=0)  # B, C * ratio**2, HW // ratio**2
            x = x.permute(0, 2, 1)

        return self.layers(x)


class VisionEncoderConfig(PretrainedConfig):
    """Configuration of the SigLIP vision tower.

    ``weight_init``, ``deterministic`` and ``num_recomputing_layers`` are not
    explicit ``__init__`` parameters. When passed as keyword arguments they are
    stored by ``PretrainedConfig.__init__``; otherwise the class-level defaults
    below apply.
    """

    model_type: str = "vision"

    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
            self,
            model_name: str = "siglip_large_patch16_384",
            image_size: int = 384,
            patch_size: int = 16,
            width: int = 1024,
            layers: int = 24,
            heads: int = 16,
            mlp_ratio: int = 4,
            global_pool: str = "map",
            ignore_head: bool = True,
            class_token: bool = False,
            num_classes: int = 0,
            use_checkpoint: bool = False,
            **kwargs
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        super().__init__(**kwargs)


class MlpProjectorConfig(PretrainedConfig):
    """Configuration of :class:`MlpProjector`.

    ``token_pooling`` is not an explicit ``__init__`` parameter. When passed as
    a keyword argument it is stored by ``PretrainedConfig.__init__``; otherwise
    the class-level default ``False`` applies.
    """

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
            self,
            projector_type: str = "downsample_mlp_gelu",
            input_dim: int = 1152,
            n_embed: int = 2048,
            depth: int = 2,
            mlp_ratio: int = 1,
            downsample_ratio: int = 2,
            **kwargs
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        super().__init__(**kwargs)


@dataclass
class DeepSeekVLV2CausalLMOutputWithPast(ModelOutput):
    """
    Base class for DeepSeek-VL2 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


def _build_sub_config(value, config_cls):
    """Return ``value`` as a ``config_cls`` instance.

    ``value`` may be ``None`` (meaning defaults), an existing ``config_cls``
    instance (used as-is) or a mapping of constructor keyword arguments.
    """
    if value is None:
        return config_cls()
    if isinstance(value, config_cls):
        return value
    return config_cls(**value)


class DeepseekVLV2Config(PretrainedConfig):
    """Top-level DeepSeek-VL2 configuration bundling vision, projector and language configs.

    ``vision_config``, ``projector_config`` and ``language_config`` are read
    from ``kwargs``. Each may be a dict of constructor arguments, an already
    built config object, or absent/``None`` to use defaults.
    """

    model_type = "deepseek_vl_v2"
    vision_config: VisionEncoderConfig
    projector_config: MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
            self,
            # Must be "1D" or "2D" (DeepseekVLV2ForCausalLM rejects anything else);
            # the default matches the class attribute above.
            tile_tag: str = "2D",
            global_view_pos: str = "head",
            candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
            **kwargs
    ):
        super().__init__(**kwargs)

        self.vision_config = _build_sub_config(kwargs.get("vision_config"), VisionEncoderConfig)
        self.projector_config = _build_sub_config(kwargs.get("projector_config"), MlpProjectorConfig)
        self.language_config = _build_sub_config(kwargs.get("language_config"), DeepseekV2Config)

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions


class DeepseekVLV2PreTrainedModel(PreTrainedModel):
    """Base class wiring :class:`DeepseekVLV2Config` into the ``PreTrainedModel`` machinery."""

    config_class = DeepseekVLV2Config
    base_model_prefix = "deepseek_vl_v2"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    """DeepSeek-VL2: vision transformer + MLP projector + DeepSeek-V2 causal language model.

    Image tiles are encoded by ``self.vision`` and projected by
    ``self.projector``. The results are written into the text embedding
    sequence at the positions flagged by ``images_seq_mask``, and the combined
    sequence is fed to ``self.language``.
    """

    def __init__(self, config: DeepseekVLV2Config):
        super().__init__(config)

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = VisionTransformer(
            img_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
            embed_dim=vision_config.width,
            depth=vision_config.layers,
            num_heads=vision_config.heads,
            mlp_ratio=vision_config.mlp_ratio,
            class_token=vision_config.class_token,
            global_pool=vision_config.global_pool,
            ignore_head=vision_config.ignore_head,
            weight_init=vision_config.weight_init,
            num_classes=0,
            deterministic=vision_config.deterministic,
            num_recomputing_layers=vision_config.num_recomputing_layers
        )

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = MlpProjector(projector_config)

        # Layout of the image token sequence.
        # FIXME: the defaults of tile_tag & global_view_pos reflect earlier experimental
        # settings; they should eventually be removed so that a missing value raises an error.
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Learned special-token embeddings used to format the image token sequence.
        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
            # NOTE: "seperator" is misspelled, but renaming the attribute would change
            # its state_dict key and break loading existing weights.
            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
        elif self.tile_tag == "1D":
            # <|tile_x|>, <|tile_global|>
            candidate_resolutions = config.candidate_resolutions
            if len(candidate_resolutions) == 0:
                raise ValueError(
                    f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}")
            tile_variants_num = len(candidate_resolutions)
            # One indicator per tile position plus one for the global view. The width
            # must match the projector output, so it is taken from projector_config.
            self.tile_indicators = nn.Parameter(
                torch.randn(size=(tile_variants_num + 1, projector_config.n_embed)) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        language_config = config.language_config
        self.language = DeepseekV2ForCausalLM(language_config)

    def prepare_inputs_embeds(
            self,
            input_ids: torch.LongTensor,
            images: Optional[torch.FloatTensor] = None,
            images_seq_mask: Optional[torch.LongTensor] = None,
            images_spatial_crop: Optional[torch.LongTensor] = None,
            **ignore_kwargs
    ):
        """Embed ``input_ids`` and write the projected image features into the image slots.

        Args:
            input_ids (torch.LongTensor): [b, T]
            images (torch.FloatTensor): [b, max_n_images, 3, height, width]
            images_seq_mask (torch.BoolTensor): [b, T]; True where an image token goes.
            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]
                (num_width_tiles, num_height_tiles) per image. A zero entry marks
                the end of that sample's images.

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        if images is None or images_spatial_crop.sum() == 0:
            return self.language.get_input_embeddings()(input_ids)

        bs, max_n_images, _ = images_spatial_crop.shape
        # Each image contributes one global view plus width * height local tiles.
        batch_num_tiles = [0 for _ in range(bs)]
        total_tiles = []
        for idx in range(bs):
            for jdx in range(max_n_images):
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break
                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)

            total_tiles.append(images[idx, :batch_num_tiles[idx]])

        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(total_tiles, dim=0)
        assert total_tiles.shape[0] == sum(batch_num_tiles)
        if total_tiles.shape[0] == 0:
            return self.language.get_input_embeddings()(input_ids)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision(total_tiles)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw ** 0.5)

        # put image tokens into the input_embeds, [b, T, D]
        input_embeds = self.language.get_input_embeddings()(input_ids)

        # Build the image token sequence according to self.tile_tag & self.global_view_pos.
        tile_index = 0
        for idx in range(images_spatial_crop.shape[0]):
            images_in_this_batch = []
            for jdx in range(images_spatial_crop.shape[1]):

                # extract global & local features
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break

                num_tiles_in_image = num_width_tiles * num_height_tiles

                # [hw, D]
                global_features = images_embeds[tile_index]

                # [num_height_tiles * num_width_tiles, hw, D]
                local_features = images_embeds[tile_index + 1: tile_index + 1 + num_tiles_in_image]

                tile_index += num_tiles_in_image + 1

                # format global and local features
                if self.tile_tag == "2D":

                    # ----------------- global view add newline -----------------
                    # [hw, D] -> [h, w, D]
                    global_features = global_features.view(h, w, n_dim)
                    # [D]     -> [h, 1, D]
                    new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)
                    # [h, w + 1, D] -> [h * (w + 1), D]
                    global_features = global_features.view(-1, n_dim)

                    # ----------------- local view add newline -----------------
                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]
                    local_features = rearrange(
                        local_features,
                        "(th tw) (h w) d -> (th h) (tw w) d",
                        th=num_height_tiles,
                        tw=num_width_tiles,
                        h=h,
                        w=w
                    )

                    # [D] -> [num_height_tiles * h, 1, D]
                    new_lines_in_local = repeat(
                        self.image_newline,
                        "d -> (th h) 1 d",
                        th=num_height_tiles,
                        h=h
                    )

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
                    local_features = local_features.view(-1, n_dim)

                    # ----------------- merge global and local tiles -----------------
                    if self.global_view_pos == "head":
                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :], local_features], dim=0)
                    else:
                        global_local_features = torch.cat(
                            [local_features, self.view_seperator[None, :], global_features], dim=0)

                else:
                    # Abandoned layout; this branch is not used in practice.
                    global_features = torch.cat(
                        [self.tile_indicators[0:1], global_features], dim=0
                    )
                    local_features = torch.cat(
                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1
                    )
                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')

                    if self.global_view_pos == "head":
                        global_local_features = torch.cat([global_features, local_features], dim=0)
                    else:
                        global_local_features = torch.cat([local_features, global_features], dim=0)

                images_in_this_batch.append(global_local_features)

            if len(images_in_this_batch) > 0:
                images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                # Fill the masked positions of this sample, in order, with the image tokens.
                input_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)

        return input_embeds

    @torch.no_grad()
    def incremental_prefilling(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,

            images: Optional[torch.FloatTensor] = None,
            images_seq_mask: Optional[torch.LongTensor] = None,
            images_spatial_crop: Optional[torch.LongTensor] = None,
            chunk_size: int = 1024
    ):
        """Prefill the KV cache in chunks of ``chunk_size`` tokens to limit peak GPU memory.

        Between chunks the cache is moved to CPU and the CUDA cache is cleared;
        it is moved back to ``inputs_embeds.device`` before each chunk and at the end.

        Returns:
            ``(inputs_embeds, prefilling_key_values)``. The key/values cover the
            first ``seq_len - 1`` positions; the last token is left for the next
            regular forward pass. ``prefilling_key_values`` is ``None`` when
            ``seq_len <= 1``, because there is nothing to prefill.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids=input_ids,
                images=images,
                images_seq_mask=images_seq_mask,
                images_spatial_crop=images_spatial_crop,
            )

            # Drop references to the (potentially large) image tensors before prefilling.
            del images
            del images_seq_mask
            del images_spatial_crop

            self._clear_cuda_cache()

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        bzs, seq_len, _ = inputs_embeds.shape
        past_key_values = None

        # remain the last token for the next forward
        prefilling_len = seq_len - 1
        for i in range(0, prefilling_len, chunk_size):
            chunk_start = i
            chunk_end = min(i + chunk_size, prefilling_len)
            chunk_inputs_embeds = inputs_embeds[:, chunk_start: chunk_end]
            # The mask covers every cached position plus the current chunk.
            chunk_attention_mask = attention_mask[:, 0: chunk_end] if attention_mask is not None else None

            # compute position_ids (the first chunk lets the language model derive them)
            if past_key_values is not None:
                position_ids = torch.arange(
                    chunk_start,
                    chunk_end,
                    dtype=torch.long,
                    device=inputs_embeds.device
                ).unsqueeze(0)
                past_key_values = self._move_past_key_values_to_gpu(past_key_values, inputs_embeds.device)
            else:
                position_ids = None

            # chunk-forward (gradients are already disabled by the decorator)
            outputs = self.forward(
                inputs_embeds=chunk_inputs_embeds,
                attention_mask=chunk_attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                use_cache=True,
            )
            # update past_key_values
            past_key_values = outputs.past_key_values
            past_key_values = self._move_past_key_values_to_cpu(past_key_values)

            del outputs, position_ids
            self._clear_cuda_cache()

        if past_key_values is None:
            # seq_len <= 1: nothing was prefilled, so there is no cache to return.
            return inputs_embeds, None

        prefilling_key_values = []
        for layer_past in past_key_values:
            prefilling_key_values.append(
                (
                    layer_past[0][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
                    layer_past[1][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
                )
            )

        return inputs_embeds, prefilling_key_values

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,

            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,

            images: Optional[torch.FloatTensor] = None,
            images_seq_mask: Optional[torch.LongTensor] = None,
            images_spatial_crop: Optional[torch.LongTensor] = None,

            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
    ):
        """Run the language model on text plus image embeddings.

        When ``inputs_embeds`` is not given, it is built from ``input_ids`` and
        the image arguments via :meth:`prepare_inputs_embeds`. All remaining
        arguments are forwarded to ``self.language``; its output is returned
        unchanged.
        """

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids=input_ids,
                images=images,
                images_seq_mask=images_seq_mask,
                images_spatial_crop=images_spatial_crop,
            )

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        # Call the module (not .forward) so registered forward hooks still run.
        outputs = self.language(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position
        )

        return outputs

    def _clear_cuda_cache(self):
        """Run the garbage collector and, when CUDA is available, empty its cache and synchronize."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _move_past_key_values_to_cpu(self, past_key_values):
        """Return a nested-tuple copy of ``past_key_values`` with every tensor on CPU."""
        if past_key_values is None:
            return None
        return tuple(tuple(t.cpu() for t in layer) for layer in past_key_values)

    def _move_past_key_values_to_gpu(self, past_key_values, device="cuda:0"):
        """Return a nested-tuple copy of ``past_key_values`` with every tensor on ``device``."""
        if past_key_values is None:
            return None
        return tuple(tuple(t.to(device) for t in layer) for layer in past_key_values)

    def prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=None,
            inputs_embeds=None,

            images: Optional[torch.FloatTensor] = None,
            images_seq_mask: Optional[torch.LongTensor] = None,
            images_spatial_crop: Optional[torch.LongTensor] = None,

            attention_mask=None,
            cache_position=None,

            pixel_values=None,
            image_sizes=None,
            num_logits_to_keep=None,
            **kwargs,
    ):
        """Delegate to the language model and attach image inputs only on the first step."""
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = self.language.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model
        cache_position = model_inputs["cache_position"]
        if cache_position[0] == 0:
            model_inputs["images"] = images
            model_inputs["images_seq_mask"] = images_seq_mask
            model_inputs["images_spatial_crop"] = images_spatial_crop

        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch dimension to follow ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past


AutoConfig.register("vision", VisionEncoderConfig)
AutoConfig.register("mlp_projector", MlpProjectorConfig)
AutoConfig.register("deepseek_vl_v2", DeepseekVLV2Config)
AutoModelForCausalLM.register(DeepseekVLV2Config, DeepseekVLV2ForCausalLM)