from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers import Idefics3ForConditionalGeneration, Idefics3Model
from transformers.cache_utils import DynamicCache
from transformers.models.idefics3.modeling_idefics3 import (
    IDEFICS3_INPUTS_DOCSTRING,
    Idefics3BaseModelOutputWithPast,
)
from transformers.utils import add_start_docstrings_to_model_forward, logging

logger = logging.get_logger(__name__)


class SmolVLMModel(Idefics3Model):
    """
    A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
    in forward. Instead, we override inputs_merger here with custom logic.
    """

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Merge text embeddings with image embeddings out-of-place (no in-place indexing).

        Expected shapes:
          - input_ids:           (B, T)
          - inputs_embeds:       (B, T, D)
          - image_hidden_states: (N, S, D), where N is the total number of images across
            the batch, S is the number of patches (or slots) per image, and D is the
            embedding dim.

        Logic:
          1) For each sample in the batch, find the <image> tokens in the text.
          2) If there are zero <image> tokens, the sample is text-only. Concatenate a
             zero-length slice of image_hidden_states but do NOT advance the offset.
             This keeps the model's image encoder in the computation graph while
             skipping "consumption" of any image block for a text-only sample.
          3) If there are <image> tokens, they appear in multiples of S per image
             (because each image occupies S embeddings). We chunk those positions into
             groups of S. For each chunk we consume one block image_hidden_states[offset]
             of shape (S, D) and place each of its rows into the text in place of a token.

        Returns:
            A tensor of shape (B, T, D).
        """
        B, T, D_text = inputs_embeds.shape
        N, S, D_img = image_hidden_states.shape
        if D_text != D_img:
            raise ValueError(
                f"Text embedding dim {D_text} != image embedding dim {D_img}"
            )

        # Index of the next unconsumed image block in image_hidden_states.
        image_offset = 0
        merged_outputs: List[torch.Tensor] = []

        for b_idx, (cur_ids, cur_embeds) in enumerate(zip(input_ids, inputs_embeds)):
            # Positions of <image> placeholder tokens in this sample.
            image_positions = (cur_ids == self.image_token_id).nonzero(as_tuple=True)[0]
            num_image_tokens = len(image_positions)

            if num_image_tokens == 0:
                # Text-only sample: concatenate a zero-length slice of the image
                # embeddings so the vision tower stays in the computation graph,
                # but do not advance image_offset (no image block is consumed).
                empty_slice = image_hidden_states[0][:0, :]
                merged_text_only = torch.cat([cur_embeds, empty_slice], dim=0)
                merged_outputs.append(merged_text_only)
                continue

            # Each image contributes exactly S placeholder tokens, so the total
            # number of <image> tokens must be a multiple of S.
            if num_image_tokens % S != 0:
                raise ValueError(
                    f"Sample {b_idx} has {num_image_tokens} <image> tokens, not a multiple of S={S}. "
                    "Cannot map them to blocks of shape (S, D)."
                )

            # Group the placeholder positions into consecutive chunks of S, one chunk per image.
            positions_list = image_positions.tolist()
            chunks = [
                positions_list[i : i + S]
                for i in range(0, num_image_tokens, S)
            ]

            # Rebuild the sequence out-of-place: alternate text segments with
            # single rows taken from the current image block.
            segments = []
            text_start = 0

            for chunk in chunks:
                # Consume one (S, D) block per chunk of S placeholder tokens.
                cur_block = image_hidden_states[image_offset]
                image_offset += 1

                for i_s, pos in enumerate(chunk):
                    # Copy the text embeddings between the previous placeholder and this
                    # one, then the matching row of the image block.
                    if pos > text_start:
                        segments.append(cur_embeds[text_start:pos])
                    row_of_block = cur_block[i_s : i_s + 1, :]
                    segments.append(row_of_block)
                    text_start = pos + 1

            # Trailing text after the last placeholder token.
            if text_start < T:
                segments.append(cur_embeds[text_start:])

            merged_sample = torch.cat(segments, dim=0)
            merged_outputs.append(merged_sample)

        # Every merged sample still has length T, so stacking restores (B, T, D).
        merged_outputs = torch.stack(merged_outputs)
        return merged_outputs

    @add_start_docstrings_to_model_forward(
        """
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        IDEFICS3_INPUTS_DOCSTRING,
    )
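    # Illustrative shapes (hypothetical): a batch of 4 samples with
    # num_images_per_sample=[1, 3, 1, 2] arrives padded as (4, 3, 3, H, W), i.e. 12
    # image slots; the 5 all-zero padding slots are detected and dropped in forward,
    # so only 7 real images go through the vision tower.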
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # Resolve batch dimensions from whichever input representation was provided.
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_seen_tokens = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_seen_tokens = past_key_values.get_seq_length()

        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # Either raw pixel_values or precomputed image_hidden_states may be given, not both.
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16/bf16 compatibility
            # Flatten the image dimension into the batch dimension: (B * num_images, C, H, W).
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Drop padding images, i.e. slots whose pixels are entirely zero.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image

            if not any(real_images_inds):
                # No real image in the batch; keep one slot so downstream shapes stay valid.
                real_images_inds[0] = True

            pixel_values = pixel_values[real_images_inds].contiguous()

            # Build (or subset) the pixel-level attention mask to match the kept images.
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            # Reduce the pixel mask to a patch-level mask for the vision transformer.
            patch_size = self.config.vision_config.patch_size
            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Encode the real images and project them into the text embedding space.
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state
            image_hidden_states = self.connector(image_hidden_states)

        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        # Merge image embeddings into the text sequence only during prefill (no cached
        # tokens yet); subsequent cached decoding steps carry no <image> tokens.
        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )


class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
    """
    A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
    instead of the default Idefics3Model.
    """

    def __init__(self, config):
        super().__init__(config)

        # Swap in the custom base model that overrides inputs_merger.
        self.model = SmolVLMModel(config)

        # LM head mapping text hidden states to vocabulary logits.
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

        # Initialize weights and apply final processing after swapping modules.
        self.post_init()
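

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the modeling code). It
    # assumes a SmolVLM checkpoint such as "HuggingFaceTB/SmolVLM-Instruct" is
    # available locally or on the Hub; substitute whichever checkpoint you actually
    # use. The dummy grey image stands in for a real input image.
    from PIL import Image
    from transformers import AutoProcessor

    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"  # assumed checkpoint id
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = SmolVLMForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

    image = Image.new("RGB", (512, 512), color=(128, 128, 128))
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")

    generated_ids = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])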