from typing import Optional

import torch
from transformers import SiglipModel


class SigLIP(SiglipModel):
    def forward(self, *args, **kwargs):
        """
        Forward pass through the SigLIP text or vision tower, depending on the inputs given.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor (text branch).
        - attention_mask (torch.Tensor): The attention mask tensor (text branch).
        - pixel_values (torch.FloatTensor): The input images tensor (vision branch).

        Returns:
        - torch.Tensor: Pooled, L2-normalized embeddings of shape (batch_size, hidden_size).
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        # Use the SigLIP model's config for these fields (if specified) instead of those of the vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )
        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # Pooled output, L2-normalized along the embedding dimension.
        embeds = outputs[1]
        embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)

        return embeds


class ColSigLIP(SiglipModel):
    def __init__(self, config):
        super().__init__(config=config)
        self.dim = 128
        self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
        self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through the SigLIP text or vision tower and a linear layer for
        dimensionality reduction, yielding one embedding per token (or image patch).

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor (text branch).
        - attention_mask (torch.Tensor): The attention mask tensor (text branch).
        - pixel_values (torch.FloatTensor): The input images tensor (vision branch).

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim).
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        # Use the SigLIP model's config for these fields (if specified) instead of those of the vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )
            # Project each patch embedding down to `dim`, then L2-normalize.
            last_hidden_states = outputs.last_hidden_state
            proj = self.custom_vision_proj(last_hidden_states)
            proj = proj / proj.norm(dim=-1, keepdim=True)
        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Project each token embedding down to `dim`, L2-normalize, and zero out padding positions.
            last_hidden_states = outputs.last_hidden_state
            proj = self.custom_text_proj(last_hidden_states)
            proj = proj / proj.norm(dim=-1, keepdim=True)
            proj = proj * attention_mask.unsqueeze(-1)

        return proj
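

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of embedding a text query with ColSigLIP. The checkpoint
# name is an assumption; any SigLIP checkpoint loadable via `from_pretrained`
# should work. Note that the custom projection heads are randomly initialized
# unless they have been trained, and that the SigLIP tokenizer may not return
# an attention mask by default, so one is built manually here.
if __name__ == "__main__":
    from transformers import SiglipProcessor

    model_name = "google/siglip-base-patch16-224"  # assumed checkpoint
    processor = SiglipProcessor.from_pretrained(model_name)
    model = ColSigLIP.from_pretrained(model_name)  # projection heads start untrained
    model.eval()

    inputs = processor(text=["a photo of a cat"], padding="max_length", return_tensors="pt")
    input_ids = inputs["input_ids"]
    # Fall back to a mask derived from the pad token; approximate if pad == eos.
    attention_mask = inputs.get(
        "attention_mask",
        (input_ids != processor.tokenizer.pad_token_id).long(),
    )

    with torch.no_grad():
        # Per-token projections of shape (batch_size, seq_len, 128),
        # with padded positions zeroed out.
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
    print(embeddings.shape)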