import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

from .helpers import PerceiverResampler
from .vlm import VLMWithLanguageStream


class AKI(VLMWithLanguageStream, PyTorchModelHubMixin):
    def __init__(
        self,
        vision_encoder_path: str,
        lang_model_path: str,
        pad_token_id: int,
        initial_tokenizer_len: Optional[int] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        decoder_layers_attr_name: Optional[str] = None,
        gradient_checkpointing: bool = False,
        base_img_size: Optional[int] = None,
        num_vision_tokens: int = 144,
    ):
        """
        Args:
            vision_encoder_path (str): path to a HF CLIPModel; its vision_model is used as the
                vision encoder, and its hidden size as the visual feature dimension
            lang_model_path (str): path to a HF causal language model
            pad_token_id (int): id of the padding token. None if no padding token; then a padding
                token will be inserted into self.special_tokens, which factory.py fills after
                creating new tokens
            initial_tokenizer_len (int, optional): size of the tokenizer vocab
            tokenizer (AutoTokenizer, optional): tokenizer used to resolve the ids of the special tokens
            decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
            base_img_size (int, optional): base input image size. Defaults to None.
            num_vision_tokens (int): number of vision tokens produced per image by the
                PerceiverResampler. Defaults to 144.
        """
        # load the vision model
        model = AutoModel.from_pretrained(vision_encoder_path)
        vision_encoder = model.vision_model
        vis_feature_dim = vision_encoder.config.hidden_size

        # load the language model
        lang_model = AutoModelForCausalLM.from_pretrained(
            lang_model_path,
            local_files_only=False,
            trust_remote_code=True,
        )

        self._special_tokens = {
            "media_token": "<image>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]

        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=PerceiverResampler(
                dim=vis_feature_dim,
                dim_inner=lang_embedding_dim,
                num_latents=num_vision_tokens,
            ),
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            base_img_size=base_img_size,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )

        if tokenizer is not None:
            self.lang_model.config.vocab_size = len(tokenizer)
            self.set_special_token_ids(
                {
                    v: tokenizer.convert_tokens_to_ids(v)
                    for v in self.special_tokens.values()
                }
            )

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder.
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def forward(
        self,
        vision_x: Optional[torch.Tensor],
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        """
        Args:
            vision_x: Vision input
                shape (B, T_img, F, C, H, W) with F=1; only F=1 (single-frame videos) is supported
                if T_img > the number of media tokens in the corresponding input_ids (lang_x),
                only the first (number of media tokens in lang_x) images are used
            lang_x: Language input ids, with media tokens denoting where
                visual media should be inserted.
                shape (B, T_txt)
            attention_mask: Attention mask. Defaults to None.
            labels: Labels. Defaults to None.
                shape (B, T_txt)
            past_key_values (Tuple[torch.Tensor], optional): Past key value pairs for each of
                the T_txt previous tokens in the language model. Defaults to None.
                list of length = number of decoder layers in the LM
                exact implementation depends on the LM, see Hugging Face docs
            past_media_locations (torch.Tensor, optional): boolean mask denoting which of the
                previous T_txt tokens were media tokens. Defaults to None. shape (B, T_txt)
            past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
                If True, includes key_values, media_locations, and vision_tokens in the output.
        """
        assert not (past_vision_tokens is None) ^ (
            past_media_locations is None
        ), "past_vision_tokens and past_media_locations must both be None or both be not None"

        # convert pixels to vision tokens
        vision_attention_mask = None
        if vision_x is not None:
            vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            padding_side="right",
            past_vision_tokens=past_vision_tokens,
        )
        output = self.lang_model(
            **new_inputs,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )

        # post-forward hooks
        self._post_forward_hook()
        return output

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                see documentation for forward
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            **kwargs: see generate documentation in Hugging Face CausalLM models.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)

        # convert pixels to vision tokens
        vision_attention_mask = None
        if vision_x is not None:
            vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )

        # customize handling of position_ids since the attention mask is already formulated as 4D
        if len(new_inputs["attention_mask"].shape) == 4:
            position_ids = new_inputs.get("position_ids", None)
            if position_ids is None:
                seq_length = new_inputs["inputs_embeds"].shape[1]
                position_ids = torch.arange(
                    seq_length, dtype=torch.long, device=new_inputs["inputs_embeds"].device
                )
                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
            new_inputs["position_ids"] = position_ids

        if past_key_values is not None:
            output = self.lang_model.generate(
                **new_inputs,
                past_key_values=past_key_values,
                num_beams=num_beams,
                use_cache=True,
                **kwargs,
            )
        else:
            output = self.lang_model.generate(
                **new_inputs,
                num_beams=num_beams,
                use_cache=True,
                **kwargs,
            )

        self._post_forward_hook()
        return output
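

# ------------------------------------------------------------------------------
# Illustrative usage sketch, not part of the model definition: it only shows how
# the class above might be instantiated and called. The checkpoint names, image
# size, prompt format, and special-token setup below are assumptions for
# demonstration; the repo's factory/config code normally creates the tokenizer
# and adds the special tokens before the model is built.
if __name__ == "__main__":
    # Hypothetical checkpoints; substitute whatever your config specifies.
    vision_path = "openai/clip-vit-large-patch14"
    lang_path = "microsoft/phi-2"

    tok = AutoTokenizer.from_pretrained(lang_path, trust_remote_code=True)
    # The media and end-of-chunk tokens must exist in the tokenizer so that
    # convert_tokens_to_ids() in __init__ can resolve them.
    tok.add_special_tokens({"additional_special_tokens": ["<image>", "<|endofchunk|>"]})
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AKI(
        vision_encoder_path=vision_path,
        lang_model_path=lang_path,
        pad_token_id=tok.pad_token_id,
        initial_tokenizer_len=len(tok),
        tokenizer=tok,
    )

    # vision_x has shape (B, T_img, F, C, H, W) with F=1; lang_x contains one
    # media token per image. A random tensor stands in for a preprocessed image.
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)
    batch = tok(["<image> Describe this image:"], return_tensors="pt")

    generated = model.generate(
        vision_x=vision_x,
        lang_x=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=20,
    )
    print(tok.decode(generated[0], skip_special_tokens=True))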