""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from dataclasses import dataclass from typing import Optional import torch from transformers.modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, ModelOutput, ) @dataclass class AlbefSimilarity(ModelOutput): sim_i2t: torch.FloatTensor = None sim_t2i: torch.FloatTensor = None sim_i2t_m: Optional[torch.FloatTensor] = None sim_t2i_m: Optional[torch.FloatTensor] = None sim_i2t_targets: Optional[torch.FloatTensor] = None sim_t2i_targets: Optional[torch.FloatTensor] = None @dataclass class AlbefIntermediateOutput(ModelOutput): # uni-modal features image_embeds: torch.FloatTensor = None text_embeds: Optional[torch.FloatTensor] = None image_embeds_m: Optional[torch.FloatTensor] = None text_embeds_m: Optional[torch.FloatTensor] = None # intermediate outputs of multimodal encoder encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None encoder_output_m: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None itm_logits: Optional[torch.FloatTensor] = None itm_labels: Optional[torch.LongTensor] = None # intermediate outputs of multimodal decoder decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None decoder_labels: Optional[torch.LongTensor] = None @dataclass class AlbefOutput(ModelOutput): # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. sims: Optional[AlbefSimilarity] = None intermediate_output: AlbefIntermediateOutput = None loss: Optional[torch.FloatTensor] = None loss_itc: Optional[torch.FloatTensor] = None loss_itm: Optional[torch.FloatTensor] = None loss_mlm: Optional[torch.FloatTensor] = None @dataclass class AlbefOutputWithLogits(AlbefOutput): logits: torch.FloatTensor = None logits_m: torch.FloatTensor = None @dataclass class AlbefOutputFeatures(ModelOutput): """ Data class of features from AlbefFeatureExtractor. Args: image_embeds: `torch.FloatTensor` of shape `(batch_size, num_patches+1, embed_dim)`, `optional` image_features: `torch.FloatTensor` of shape `(batch_size, num_patches+1, feature_dim)`, `optional` text_embeds: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, embed_dim)`, `optional` text_features: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, feature_dim)`, `optional` The first embedding or feature is for the [CLS] token. Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. """ image_embeds: Optional[torch.FloatTensor] = None image_embeds_proj: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None text_embeds_proj: Optional[torch.FloatTensor] = None multimodal_embeds: Optional[torch.FloatTensor] = None