import functools
import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
from .utils import extend_instance, stack_with_padding, stack_with_padding_2D_attention, num_params, getattr_recursive
from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import CLIPVisionModel
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer


class VLM(nn.Module):
    """
    Generic vision-language model (VLM) class.
    A VLM consists of four components:
        1. A vision encoder that extracts features from pixels, e.g. CLIP
            input: (B, T_img, F, C, H, W)
            output: (B, T_img, F, v, d)
        2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
            input: (B, T_img, F, v, d)
            output: (B, T_img, n, d)
        3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
        4. A language model
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        gradient_checkpointing: bool = False,
        base_img_size: Optional[int] = None,
    ):
        """
        Args:
            vision_encoder (nn.Module): e.g. CLIP
            vision_tokenizer (nn.Module): e.g. PerceiverResampler
            lang_model (nn.Module): e.g. MPT
            initial_tokenizer_len (int): size of the original tokenizer vocab
            pad_token_id (int): id of the pad token
            gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            base_img_size (int, optional): Base image resolution expected by the vision encoder.
                If None, it is inferred from the vision encoder's config. Defaults to None.
        """
        super().__init__()

        self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
        if hasattr(lang_model.config, "d_model"):
            self.lang_hidden_dim = lang_model.config.d_model  # some configs (e.g. MPT) use d_model
        else:
            self.lang_hidden_dim = lang_model.config.hidden_size
        self.vis_embedding_dim = vision_tokenizer.dim_media
        self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media

        self.vision_encoder = vision_encoder
        self.vision_tokenizer = vision_tokenizer
        self.lang_model = lang_model

        if base_img_size is None:
            if isinstance(self.vision_encoder, (CLIPVisionModel, SiglipVisionTransformer)):
                base_img_size = self.vision_encoder.config.image_size
            else:
                base_img_size = self.vision_encoder.image_size[0]
        self.base_img_size = base_img_size

        # Swap in decoupled input/output embeddings so that newly added special tokens
        # get their own trainable rows while the original vocab rows are left untouched.
        self.pad_token_id = pad_token_id
        self.initial_tokenizer_len = initial_tokenizer_len
        input_embeds = DecoupledEmbedding(
            max_original_id=initial_tokenizer_len - 1,
            num_additional_embeddings=len(self.special_tokens),
            _weight=self.lang_model.get_input_embeddings().weight,
            pad_token_id=self.pad_token_id,
        )
        if hasattr(input_embeds, "additional_embedding"):
            input_embeds.additional_embedding.weight.data.normal_(
                mean=0.0,
                std=self.lang_model.config.initializer_range
                if hasattr(self.lang_model.config, "initializer_range")
                else 0.02,
            )
        self.lang_model.set_input_embeddings(input_embeds)

        out_embeds = DecoupledLinear(
            max_original_id=initial_tokenizer_len - 1,
            additional_out_features=len(self.special_tokens),
            _weight=self.lang_model.get_output_embeddings().weight,
            _bias=self.lang_model.get_output_embeddings().bias
            if hasattr(self.lang_model.get_output_embeddings(), "bias")
            else None,
        )
        if hasattr(out_embeds, "additional_fc"):
            out_embeds.additional_fc.weight.data.normal_(
                mean=0.0,
                std=self.lang_model.config.initializer_range
                if hasattr(self.lang_model.config, "initializer_range")
                else 0.02,
            )
        self.lang_model.set_output_embeddings(out_embeds)

        self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: Optional[torch.Tensor],
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        """
        Args:
            vision_x: Vision input
                shape (B, T_img, F, C, H, W) with F=1
                only F = 1 is supported (single-frame videos)
                if T_img exceeds the number of media tokens in the corresponding
                lang_x, the extra images are ignored
            lang_x: Language input ids, with media tokens denoting where
                visual media should be inserted.
                shape (B, T_txt)
            attention_mask: Attention mask. Defaults to None.
            labels: Labels. Defaults to None.
                shape (B, T_txt)
            past_key_values (Tuple[torch.Tensor], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
                list of length = number of decoder layers in the LM
                exact implementation depends on LM, see Hugging Face docs
            past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
                shape (B, T_txt)
            past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
            use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
                If True, includes key_values, media_locations, and vision_tokens in the output.
        """
        assert not (past_vision_tokens is None) ^ (
            past_media_locations is None
        ), "past_vision_tokens and past_media_locations must both be None or both be not None"

        # convert pixels to vision tokens
        if vision_x is not None:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_tokens = self.vision_tokenizer(vision_features)
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            padding_side="right",
            past_vision_tokens=past_vision_tokens,
        )
        output = self.lang_model(
            **new_inputs,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )

        # postprocess the output, e.g. to restore the (B, T_txt, vocab_size) logits shape
        output = self._postprocess_outputs_from_forward(
            output=output,
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            use_cache=use_cache,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
        )

        self._post_forward_hook()
        return output
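
    # Minimal usage sketch (hypothetical tensors; assumes a concrete subclass whose
    # tokenizer has an added "<image>" token registered via set_special_token_ids):
    #   vision_x = torch.randn(1, 1, 1, 3, model.base_img_size, model.base_img_size)
    #   lang_x = tokenizer("<image> What is shown here?", return_tensors="pt").input_ids
    #   out = model(vision_x=vision_x, lang_x=lang_x, attention_mask=torch.ones_like(lang_x))
    #   # out.logits has shape (1, lang_x.shape[1], vocab_size)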

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute vision features from the vision input by passing it through the vision encoder.
        Args:
            vision_x: Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            if self.vision_encoder.__class__.__name__ == "TimmModel":
                vision_x = self.vision_encoder.trunk.forward_features(vision_x)
            elif self.vision_encoder.__class__.__name__ in ["CLIPVisionModel", "SiglipVisionTransformer"]:
                vision_x = self.vision_encoder(vision_x).last_hidden_state
            else:
                vision_x = self.vision_encoder(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

    def _concat_vision_cache(
        self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
    ):
        """
        Helper function to include the past vision tokens and past media locations in the output.
        """
        if use_cache:
            if past_media_locations is not None and past_vision_tokens is not None:
                if vision_tokens is not None:
                    updated_vision_tokens = torch.cat(
                        [
                            past_vision_tokens,
                            vision_tokens,
                        ],
                        dim=1,
                    )
                else:
                    updated_vision_tokens = past_vision_tokens
                updated_media_locations = torch.cat(
                    [
                        past_media_locations,
                        lang_x == self.media_token_id,
                    ],
                    dim=1,
                )
            else:
                updated_vision_tokens = vision_tokens
                updated_media_locations = lang_x == self.media_token_id

        else:
            updated_vision_tokens = None
            updated_media_locations = None

        return updated_vision_tokens, updated_media_locations

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                see documentation for forward
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)

        if vision_x is not None:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_tokens = self.vision_tokenizer(vision_features)
        else:
            vision_tokens = None

        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )
        output = self.lang_model.generate(
            **new_inputs,
            past_key_values=past_key_values,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output
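
    # Generation sketch (hypothetical; remaining kwargs such as max_new_tokens are
    # forwarded to the underlying Hugging Face generate):
    #   generated = model.generate(
    #       vision_x=vision_x,
    #       lang_x=lang_x,
    #       attention_mask=attention_mask,
    #       max_new_tokens=32,
    #   )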

    @property
    def num_trainable_params(self):
        """Return the number of trainable parameters"""
        return num_params(self, filter_to_trainable=True)

    def set_trainable(self):
        """
        Freeze appropriate parameters in the model.
        """
        raise NotImplementedError

    def group_params_by_weight_decay(self):
        """
        Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
        """
        params_with_wd, params_without_wd = [], []
        for n, p in self.named_parameters():
            if p.requires_grad:
                if self._should_apply_weight_decay(n):
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
        return params_with_wd, params_without_wd
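
    # Typical optimizer wiring (sketch; optimizer choice and hyperparameters are illustrative):
    #   wd_params, no_wd_params = model.group_params_by_weight_decay()
    #   optimizer = torch.optim.AdamW(
    #       [
    #           {"params": wd_params, "weight_decay": 0.1},
    #           {"params": no_wd_params, "weight_decay": 0.0},
    #       ],
    #       lr=1e-4,
    #   )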

    def _should_apply_weight_decay(self, parameter_name):
        """
        Return whether weight decay should be applied to a parameter.
        """
        raise NotImplementedError

    @property
    def special_tokens(self):
        """
        Returns a dict mapping from the attribute name of a special token to its string format,
        e.g. "media_token": "<image>"
        """
        assert (
            "media_token" in self._special_tokens
        ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
        return self._special_tokens

    @property
    def special_token_ids(self):
        """
        Returns a list of the special token ids
        """
        return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]

    def set_special_token_ids(self, string_to_ids):
        """
        Args:
            string_to_ids (dict): mapping from token string to id
        """
        assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
        for att_name, token_str in self.special_tokens.items():
            token_id = string_to_ids[token_str]
            setattr(self, f"{att_name}_id", token_id)
            setattr(self.lang_model, f"{att_name}_id", token_id)
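
    # Example wiring (sketch; the exact token strings come from the subclass's
    # self._special_tokens, e.g. {"media_token": "<image>"}):
    #   tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    #   model.set_special_token_ids(
    #       {"<image>": tokenizer.convert_tokens_to_ids("<image>")}
    #   )
    #   # afterwards model.media_token_id holds the id of "<image>"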

    def init_gradient_checkpointing(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper,
            CheckpointWrapper,
            CheckpointImpl,
            apply_activation_checkpointing,
        )
        from functools import partial

        non_reentrant_wrapper = partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            self,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, CheckpointWrapper),
        )


class VLMWithLanguageStream(VLM):
    """
    VLM that fuses modalities by inserting vision tokens directly into the language stream.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        base_img_size: Optional[int] = None,
    ):
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            pad_token_id=pad_token_id,
            base_img_size=base_img_size,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.decoder_layers_attr_name = decoder_layers_attr_name
        for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
            block._use_gradient_checkpointing = gradient_checkpointing

    @staticmethod
    def _make_modality_mutual_mask(
        attention_mask_2d: torch.Tensor,
        image_start_idx: int,
        text_start_idx: int,
        text_end_idx: int,
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Make the mutual (non-causal across modalities) attention mask: start from a
        causal mask, then additionally allow the image token positions to attend to
        the question text span.
        """
        tgt_len = input_ids_shape[0]
        # causal lower-triangular mask of shape (tgt_len, tgt_len)
        mask = torch.full((tgt_len, tgt_len), 0, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1)

        # let image positions attend ahead to the question text
        mask[image_start_idx:text_start_idx, text_start_idx:text_end_idx] = 1

        mask = mask.to(dtype)
        mask = mask[None, :, :].expand(1, tgt_len, tgt_len)

        # fold in the 2D padding mask
        expanded_mask = attention_mask_2d[None, None, :].expand(1, tgt_len, tgt_len).to(torch.float32)
        inverted_mask = 1.0 - expanded_mask
        expanded_attn_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(torch.float32).min)

        expanded_attn_mask = mask.masked_fill(expanded_attn_mask.bool(), 0)

        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask
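
    # Tiny worked example (illustrative): with tgt_len=5, vision tokens at positions
    # 1-2 (image_start_idx=1, text_start_idx=3) and question text ending at
    # text_end_idx=5, the 0/1 mask before padding is folded in looks like:
    #   1 0 0 0 0
    #   1 1 0 1 1   <- image rows may also attend ahead to the question text
    #   1 1 1 1 1
    #   1 1 1 1 0
    #   1 1 1 1 1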

    def _prepare_inputs_for_forward(
        self,
        vision_tokens: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
        past_key_values=None,
        vision_attention_mask: Optional[torch.Tensor] = None,
        past_media_locations: torch.Tensor = None,
        past_vision_tokens: torch.Tensor = None,
        padding_side: str = "left",
        num_beams: int = 1,
    ):
        """
        Insert the vision tokens directly into the language stream.
        This requires us to modify the input_ids, attention_mask, and labels.
        [NOTE]: This function can be changed to fit the ablation setting of putting text before images.
        """
        if past_key_values is not None:
            past_len = past_key_values[0][0].shape[2]
            assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
                "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
                + "Check that you've expanded the attention mask to account for past image tokens."
            )

        if vision_tokens is None:
            return {
                "input_ids": lang_x,
                "attention_mask": attention_mask,
                "labels": labels,
            }

        # embed the language tokens
        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)

        # build up the multimodal inputs sequence by sequence
        B = lang_x.shape[0]
        has_labels = labels is not None
        multimodal_embeds = []
        multimodal_attention_mask = []
        multimodal_labels = [] if has_labels else None
        for i in range(B):
            # positions of the media tokens in this sequence
            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]

            # locate the question span boundary
            # NOTE: 32001 is a hardcoded special-token id; positions up to it are
            # treated as the question span that image tokens may attend to
            question_token_idx = torch.where(lang_x[i] == 32001)[0]
            if len(question_token_idx) != 0:
                question_token_idx = question_token_idx[0]
            else:
                question_token_idx = 0

            if len(image_token_idxs) == 0:
                multimodal_embeds.append(lang_embeds[i].clone())
                new_attention_mask = self._make_modality_mutual_mask(
                    attention_mask_2d=attention_mask[i].clone(),
                    image_start_idx=0,
                    text_start_idx=0,
                    text_end_idx=question_token_idx,
                    input_ids_shape=attention_mask[i].shape,
                    dtype=attention_mask[i].dtype,
                    device=attention_mask[i].device,
                )
                multimodal_attention_mask.append(new_attention_mask)
                if has_labels:
                    multimodal_labels.append(labels[i].clone())
                continue

            # splice the vision tokens in at each media token position
            new_embed = lang_embeds[i].clone()
            new_attention_mask = (
                attention_mask[i].clone() if attention_mask is not None else None
            )
            if has_labels:
                new_label = labels[i].clone()

            for img_num in range(len(image_token_idxs)):
                img_idx = image_token_idxs[img_num]
                assert (
                    vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
                ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
                    vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"

                num_vis_tokens = self.num_tokens_per_vis
                vis_attention_mask = torch.ones(
                    num_vis_tokens, dtype=torch.long
                ).to(attention_mask.device)

                # shift later media token indices to account for the inserted vision tokens
                for j in range(img_num + 1, len(image_token_idxs)):
                    image_token_idxs[j] += num_vis_tokens

                new_embed = torch.cat(
                    (
                        new_embed[:img_idx],
                        vision_tokens[i][img_num],
                        new_embed[img_idx + 1 :],
                    ),
                    dim=0,
                )
                new_attention_mask = torch.cat(
                    (
                        new_attention_mask[:img_idx],
                        vis_attention_mask,
                        new_attention_mask[img_idx + 1 :],
                    ),
                    dim=0,
                )

                new_attention_mask = self._make_modality_mutual_mask(
                    attention_mask_2d=new_attention_mask,
                    image_start_idx=img_idx,
                    text_start_idx=img_idx + len(vis_attention_mask),
                    text_end_idx=question_token_idx + len(vis_attention_mask),
                    input_ids_shape=new_attention_mask.shape,
                    dtype=new_attention_mask.dtype,
                    device=new_attention_mask.device,
                )

                if has_labels:
                    new_label = torch.cat(
                        (
                            new_label[:img_idx],
                            torch.ones(num_vis_tokens, dtype=torch.long).to(
                                labels.device
                            )
                            * -100,
                            new_label[img_idx + 1 :],
                        ),
                        dim=0,
                    )
            multimodal_embeds.append(new_embed)
            multimodal_attention_mask.append(new_attention_mask)
            if has_labels:
                multimodal_labels.append(new_label)

        # stack the per-sequence tensors with padding
        multimodal_embeds = stack_with_padding(
            multimodal_embeds,
            padding_value=self.pad_token_id,
            padding_side=padding_side,
        )
        multimodal_attention_mask = stack_with_padding_2D_attention(
            multimodal_attention_mask,
        )
        if has_labels:
            multimodal_labels = stack_with_padding(
                multimodal_labels,
                padding_value=-100,
                padding_side=padding_side,
            )

        return {
            "inputs_embeds": multimodal_embeds,
            "attention_mask": multimodal_attention_mask,
            "labels": multimodal_labels,
        }
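
    # Sketch of the expansion for one sequence (hypothetical ids): with
    # num_tokens_per_vis = 3 and lang_x[i] = [bos, <image>, q1, q2], the <image>
    # position is replaced by the 3 vision-token embeddings, giving an inputs_embeds
    # row of length 6, and labels [l0, l1, l2, l3] become [l0, -100, -100, -100, l2, l3].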

    def _postprocess_outputs_from_forward(
        self,
        output: CausalLMOutputWithPast,
        lang_x: torch.Tensor,
        vision_tokens: torch.Tensor,
        past_vision_tokens: torch.Tensor,
        past_media_locations: torch.Tensor,
        use_cache: bool = False,
    ):
        # cache the vision tokens and media locations if requested
        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
            use_cache=use_cache,
        )

        # drop the logits corresponding to the inserted vision tokens so that the
        # returned logits line up with the original text positions in lang_x
        logits = output.logits
        batch_logits = []
        B, T_txt = lang_x.shape
        for i in range(B):
            sequence_logits = []
            logits_j = 0
            img_id = 0
            for j in range(T_txt):
                if lang_x[i, j] != self.media_token_id:
                    sequence_logits.append(logits[i, logits_j])
                    logits_j += 1
                else:
                    # keep the logit at the first vision-token position for the media
                    # token, then skip over the remaining vision-token positions
                    sequence_logits.append(logits[i, logits_j])
                    logits_j += vision_tokens[i][img_id].shape[0]
                    img_id += 1
            sequence_logits = torch.stack(sequence_logits, dim=0)
            batch_logits.append(sequence_logits)

        batch_logits = torch.stack(batch_logits, dim=0)
        # the logits should line up with the original input ids
        assert batch_logits.shape[:2] == (B, T_txt)

        output = VLMOutputWithPast(
            loss=output.loss,
            logits=batch_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            past_media_locations=updated_media_locations,
            past_vision_tokens=updated_vision_tokens,
        )

        return output

    def _post_forward_hook(self):
        pass

    def get_fsdp_lambda_fn(self):
        """
        Returns the lambda function used to decide how to perform FSDP wrapping.
        """
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointWrapper,
        )

        decoder_block_class = getattr_recursive(
            self.lang_model, self.decoder_layers_attr_name
        )[0].__class__

        def lambda_fn(module: nn.Module):
            if getattr(module, "_use_gradient_checkpointing", False) and not isinstance(
                module, CheckpointWrapper
            ):
                return False
            if module is self.vision_tokenizer:
                return True
            if isinstance(module, decoder_block_class):
                return True

        return lambda_fn

    def get_fsdp_wrapping_policy(self):
        """
        Returns the policy used to decide how to perform FSDP wrapping.
        """
        from torch.distributed.fsdp.wrap import _or_policy, _module_wrap_policy, transformer_auto_wrap_policy
        from open_clip.transformer import VisionTransformer, ResidualAttentionBlock
        from open_clip.timm_model import TimmModel
        from timm.models.vision_transformer import Block
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
        from transformers.models.phi.modeling_phi import PhiDecoderLayer

        # Phi-3 ships its decoder layer as remote code, so the class has to be pulled
        # out of the dynamically generated transformers_modules package if present.
        try:
            import importlib
            commit_hash = str(type(self.lang_model)).split("instruct.")[1].split(".modeling")[0]
            module_name = f"transformers_modules.microsoft.Phi-3-mini-128k-instruct.{commit_hash}.modeling_phi3"
            module = importlib.import_module(module_name)
            Phi3DecoderLayer = module.Phi3DecoderLayer
            import_phi3 = True
        except IndexError:
            import_phi3 = False

        # wrap the vision encoder and its transformer blocks
        from transformers import SiglipVisionModel
        if isinstance(self.vision_encoder, SiglipVisionModel):
            from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead

            vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={SiglipVisionModel})
            transformer_layer_cls_vit = {SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead}
            vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
            vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])
        else:
            vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer, TimmModel})
            transformer_layer_cls_vit = {ResidualAttentionBlock, Block}
            vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
            vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])

        # wrap the language model's decoder blocks
        transformer_layer_cls = {LlamaDecoderLayer, PhiDecoderLayer}
        if import_phi3:
            transformer_layer_cls.add(Phi3DecoderLayer)
        llm_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls)

        # wrap the vision tokenizer
        # NOTE: LinearPatchProjection and PerceiverResampler are assumed to be imported
        # elsewhere in this module (alongside the vision tokenizer definitions).
        vis_tokenizer_policy = functools.partial(_module_wrap_policy, module_classes={LinearPatchProjection, PerceiverResampler})
        return functools.partial(
            _or_policy,
            policies=[
                vision_wrap_policy,
                llm_transformer_block_policy,
                vis_tokenizer_policy,
            ])
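
    # Typical FSDP usage (sketch; process-group setup and other FSDP arguments omitted):
    #   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    #   model = FSDP(model, auto_wrap_policy=model.get_fsdp_wrapping_policy())
    #   model.init_gradient_checkpointing()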

    def group_params_by_weight_decay(self):
        """
        Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
        """
        params_with_wd, params_without_wd = [], []
        for n, p in self.named_parameters():
            if p.requires_grad:
                # exclude the (decoupled) token embeddings from weight decay
                if "lang_model.model.embed_tokens" in n:
                    params_without_wd.append(p)
                else:
                    params_with_wd.append(p)
        return params_with_wd, params_without_wd

    @property
    def num_params_per_module(self):
        """Return a printable summary of the number of parameters per module"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
                f"Language model: {num_params(self.lang_model):,} parameters",
            ]
        )

    @property
    def num_trainable_params_per_module(self):
        """Return a printable summary of the number of trainable parameters per module"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
                f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
            ]
        )