jchwenger's picture
Upload 351 files (#2)
d9272c6 verified
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
ModelOutput,
)
@dataclass
class AlbefSimilarity(ModelOutput):
    """
    Container for image-text similarity matrices and their targets.

    Every field defaults to ``None``, so every annotation is ``Optional``.

    NOTE(review): judging by the names only, ``i2t`` / ``t2i`` stand for
    image-to-text / text-to-image and the ``_m`` suffix marks a second
    (presumably momentum-model) variant of the same quantity -- confirm
    against the model code that populates this object.
    """

    # Primary similarity scores.
    sim_i2t: Optional[torch.FloatTensor] = None
    sim_t2i: Optional[torch.FloatTensor] = None

    # "_m" variants of the similarity scores.
    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    # Targets associated with the similarity scores.
    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None
@dataclass
class AlbefIntermediateOutput(ModelOutput):
    """
    Intermediate results collected while running an ALBEF-style model.

    Groups the uni-modal embeddings, the multimodal encoder outputs and the
    multimodal decoder outputs. Every field defaults to ``None``, so every
    annotation is ``Optional``; which fields are filled in depends on the
    model that builds this object.
    """

    # uni-modal features
    image_embeds: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None

    # "_m" variants of the uni-modal features
    # NOTE(review): presumably produced by a momentum model -- confirm.
    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_m: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None
@dataclass
class AlbefOutput(ModelOutput):
    """
    Top-level output bundle: similarities, intermediate results and losses.

    Every field defaults to ``None``, so every annotation is ``Optional``.
    """

    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[AlbefSimilarity] = None

    intermediate_output: Optional[AlbefIntermediateOutput] = None

    # Total loss followed by the individual loss terms.
    # NOTE(review): itc / itm / mlm presumably abbreviate image-text
    # contrastive, image-text matching and masked language modeling -- confirm.
    loss: Optional[torch.FloatTensor] = None
    loss_itc: Optional[torch.FloatTensor] = None
    loss_itm: Optional[torch.FloatTensor] = None
    loss_mlm: Optional[torch.FloatTensor] = None
@dataclass
class AlbefOutputWithLogits(AlbefOutput):
    """
    ``AlbefOutput`` extended with two logits tensors.

    Both extra fields default to ``None`` and are therefore annotated
    ``Optional``.
    """

    logits: Optional[torch.FloatTensor] = None
    # NOTE(review): "_m" presumably marks the momentum-model logits -- confirm.
    logits_m: Optional[torch.FloatTensor] = None
@dataclass
class AlbefOutputFeatures(ModelOutput):
    """
    Data class of features from AlbefFeatureExtractor.

    Args:
        image_embeds: `torch.FloatTensor` of shape `(batch_size, num_patches+1, embed_dim)`, `optional`
        image_embeds_proj: `torch.FloatTensor` of shape `(batch_size, num_patches+1, feature_dim)`, `optional`
        text_embeds: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, embed_dim)`, `optional`
        text_embeds_proj: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, feature_dim)`, `optional`
        multimodal_embeds: `torch.FloatTensor`, `optional`. Shape is not specified here;
            check AlbefFeatureExtractor before relying on it.

    The first embedding or feature is for the [CLS] token.

    Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    They are stored in the `*_embeds_proj` fields (previously documented under the names
    `image_features` / `text_features`; mapping inferred from the `_proj` suffix -- confirm
    against AlbefFeatureExtractor).
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None