HUANG-Stephanie's picture
Upload 88 files
9ff79dc verified
import os
from typing import Optional
import torch
from transformers import SiglipModel
class SigLIP(SiglipModel):
def forward(self, *args, **kwargs):
"""
Forward pass through Llama and the linear layer for dimensionality reduction
Args:
- input_ids (torch.LongTensor): The input tokens tensor.
- attention_mask (torch.LongTensor): The attention mask tensor.
Returns:
- torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
"""
return self.forward_branch(*args, **kwargs)
def forward_branch(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is not None:
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
outputs = self.vision_model(
pixel_values=pixel_values.to(dtype=self.dtype),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
else:
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
embeds = outputs[1]
# normalized features
embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
return embeds
class ColSigLIP(SiglipModel):
def __init__(self, config):
super(ColSigLIP, self).__init__(config=config)
self.dim = 128
self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
self.main_input_name = "doc_input_ids"
def forward(self, *args, **kwargs):
"""
Forward pass through Llama and the linear layer for dimensionality reduction
Args:
- input_ids (torch.LongTensor): The input tokens tensor.
- attention_mask (torch.LongTensor): The attention mask tensor.
Returns:
- torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
"""
return self.forward_branch(*args, **kwargs)
def forward_branch(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is not None:
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
outputs = self.vision_model(
pixel_values=pixel_values.to(dtype=self.dtype),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
last_hidden_states = outputs.last_hidden_state
proj = self.custom_vision_proj(last_hidden_states)
# normalize l2 norm
proj = proj / proj.norm(dim=-1, keepdim=True)
else:
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs.last_hidden_state
proj = self.custom_text_proj(last_hidden_states)
# normalize l2 norm
proj = proj / proj.norm(dim=-1, keepdim=True)
proj = proj * attention_mask.unsqueeze(-1)
# normalized features
return proj