import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

from .helpers import PerceiverResampler
from .vlm import VLMWithLanguageStream


class AKI(VLMWithLanguageStream, PyTorchModelHubMixin):
    def __init__(
        self,
        vision_encoder_path: str,
        lang_model_path: str,
        pad_token_id: int,
        initial_tokenizer_len: Optional[int] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        base_img_size: Optional[int] = None,
        num_vision_tokens: int = 144,
    ):
""" |
|
Args: |
|
vision_encoder (nn.Module): HF CLIPModel |
|
lang_encoder (nn.Module): HF causal language model |
|
vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder |
|
initial_tokenizer_len (int): size of the tokenizer vocab |
|
padding_token_id (int): id of the padding token. None if no padding token; then a padding token |
|
will be inserted into self.special_tokens, which factory.py fills after creating new tokens |
|
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. |
|
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False. |
|
""" |
        # Load the vision encoder and read off its visual feature dimension.
        model = AutoModel.from_pretrained(vision_encoder_path)
        vision_encoder = model.vision_model
        vis_feature_dim = vision_encoder.config.hidden_size

        # Load the causal language model.
        lang_model = AutoModelForCausalLM.from_pretrained(
            lang_model_path,
            local_files_only=False,
            trust_remote_code=True,
        )

        self._special_tokens = {
            "media_token": "<image>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=PerceiverResampler(
                dim=vis_feature_dim,
                dim_inner=lang_embedding_dim,
                num_latents=num_vision_tokens,
            ),
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            base_img_size=base_img_size,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )

        if tokenizer is not None:
            self.lang_model.config.vocab_size = len(tokenizer)
            self.set_special_token_ids(
                {
                    v: tokenizer.convert_tokens_to_ids(v)
                    for v in self.special_tokens.values()
                }
            )

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder.
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def forward(
        self,
        vision_x: Optional[torch.Tensor],
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
""" |
|
Args: |
|
vision_x: Vision input |
|
shape (B, T_img, F, C, H, W) with F=1 |
|
only F = 1 is supported (single-frame videos) |
|
if T_img > the number of media tokens in the corresponding input_ids (lang_x), |
|
only the first number of media tokens in lang_x are used |
|
lang_x: Language input ids, with media tokens denoting where |
|
visual media should be inserted. |
|
shape (B, T_txt) |
|
attention_mask: Attention mask. Defaults to None. |
|
labels: Labels. Defaults to None. |
|
shape (B, T_txt) |
|
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None. |
|
list of length = number of decoder layers in the LM |
|
exact implementation depends on LM, see Hugging Face docs |
|
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None. |
|
shape (B, T_txt) |
|
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None. |
|
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False. |
|
If True, includes key_values, media_locations, and vision_tokens in the output. |
|
""" |
|
        assert not (past_vision_tokens is None) ^ (
            past_media_locations is None
        ), "past_vision_tokens and past_media_locations must both be None or both be not None"

        # Encode the images and resample them into a fixed number of vision tokens.
        vision_attention_mask = None
        if vision_x is not None:
            vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
        else:
            vision_tokens = None

        # Merge the vision tokens into the language stream and run the LM.
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            padding_side="right",
            past_vision_tokens=past_vision_tokens,
        )
        output = self.lang_model(
            **new_inputs,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )

        self._post_forward_hook()
        return output

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
""" |
|
Generate text conditioned on vision and language inputs. |
|
Args: |
|
vision_x (torch.Tensor): Vision input |
|
shape (B, T_img, F, C, H, W) |
|
see documentation for forward |
|
lang_x (torch.Tensor): Language input |
|
shape (B, T_txt) |
|
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. |
|
**kwargs: see generate documentation in Hugging Face CausalLM models. |
|
Returns: |
|
torch.Tensor: lang_x with generated tokens appended to it |
|
""" |
|
        num_beams = kwargs.pop("num_beams", 1)

        # Encode the images and resample them into a fixed number of vision tokens.
        vision_attention_mask = None
        if vision_x is not None:
            vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
        else:
            vision_tokens = None

        # Merge the vision tokens into the language stream (left padding for generation).
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )

        # When the prepared attention mask is 4D, supply explicit sequential
        # position ids over the embedded sequence if none were provided.
        if len(new_inputs["attention_mask"].shape) == 4:
            position_ids = new_inputs.get("position_ids", None)
            if position_ids is None:
                seq_length = new_inputs["inputs_embeds"].shape[1]
                position_ids = torch.arange(
                    seq_length, dtype=torch.long, device=new_inputs["inputs_embeds"].device
                )
                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
            new_inputs["position_ids"] = position_ids

        if past_key_values is not None:
            output = self.lang_model.generate(
                **new_inputs,
                past_key_values=past_key_values,
                num_beams=num_beams,
                use_cache=True,
                **kwargs,
            )
        else:
            output = self.lang_model.generate(
                **new_inputs,
                num_beams=num_beams,
                use_cache=True,
                **kwargs,
            )
        self._post_forward_hook()
        return output
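
# Illustrative usage (a minimal sketch, not part of the original module). The
# tokenizer setup, checkpoint paths, image size, and generation settings below
# are placeholder assumptions; in practice factory.py handles model
# construction and special-token registration.
#
#   import torch
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/lang_model")
#   tokenizer.add_special_tokens(
#       {"additional_special_tokens": ["<image>", "<|endofchunk|>"]}
#   )
#   model = AKI(
#       vision_encoder_path="path/to/vision_encoder",
#       lang_model_path="path/to/lang_model",
#       pad_token_id=tokenizer.pad_token_id,
#       initial_tokenizer_len=len(tokenizer),
#       tokenizer=tokenizer,
#   )
#
#   # vision_x has shape (B, T_img, F, C, H, W) with F=1; lang_x must contain
#   # one <image> token per image in vision_x.
#   vision_x = torch.randn(1, 1, 1, 3, 224, 224)
#   lang_x = tokenizer(["<image> A photo of"], return_tensors="pt")
#   generated = model.generate(
#       vision_x=vision_x,
#       lang_x=lang_x["input_ids"],
#       attention_mask=lang_x["attention_mask"],
#       max_new_tokens=20,
#   )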