""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import torch
import torch.nn.functional as F

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer


@registry.register_model("blip2_image_text_matching")
class Blip2ITM(Blip2Qformer):
""" | |
BLIP Image-Text Matching (ITM) model. | |
Supported model types: | |
- pretrained: pretrained model | |
- coco: fintuned model on coco | |
Usage: | |
>>> from lavis.models import load_model | |
>>> model = load_model("blip2_image_text_matching", "pretrained") | |
>>> model = load_model("blip2_image_text_matching", "coco") | |
""" | |

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

    def forward(self, samples, match_head="itm"):
        image = samples["image"]
        caption = samples["text_input"]

        # Encode the image with the (frozen) ViT and normalize with LayerNorm.
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        # Tokenize the caption, truncating to the model's maximum text length.
        text = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)

        if match_head == "itm":
            # ITM head: query tokens cross-attend to the image features while
            # self-attending jointly with the caption tokens; each query's
            # output is classified as match / no-match and the logits are
            # averaged over queries.
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                image.device
            )
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
            output_itm = self.Qformer.bert(
                text.input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
            itm_logit = self.itm_head(itm_embeddings)
            itm_logit = itm_logit.mean(dim=1)

            return itm_logit

        elif match_head == "itc":
            # ITC head: compute unimodal features and score by cosine similarity.
            # Image side: one projected, normalized feature per query token.
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            image_feats = F.normalize(
                self.vision_proj(query_output.last_hidden_state), dim=-1
            )

            # Text side: projected, normalized [CLS] embedding of the caption.
            text_output = self.Qformer.bert(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
            )
            text_feat = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )

            # Similarity is the maximum over all query tokens.
            sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
            sim, _ = torch.max(sims, dim=1)

            return sim
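

# A minimal usage sketch (not part of the original module): it shows how the two
# match heads above can score an image-caption pair. It assumes the standard
# LAVIS loader `load_model_and_preprocess`; the image path "example.jpg" and the
# caption string are hypothetical placeholders.
if __name__ == "__main__":
    from PIL import Image
    from lavis.models import load_model_and_preprocess

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, vis_processors, text_processors = load_model_and_preprocess(
        name="blip2_image_text_matching",
        model_type="pretrained",
        is_eval=True,
        device=device,
    )

    raw_image = Image.open("example.jpg").convert("RGB")  # hypothetical image
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    caption = text_processors["eval"]("a dog running on the beach")  # hypothetical caption

    samples = {"image": image, "text_input": caption}

    with torch.no_grad():
        # ITM head returns logits over (no match, match); softmax over the last
        # dimension gives the match probability.
        itm_logit = model(samples, match_head="itm")
        itm_score = F.softmax(itm_logit, dim=1)[:, 1]
        print(f"ITM match probability: {itm_score.item():.4f}")

        # ITC head returns the max cosine similarity between the image query
        # features and the text feature.
        itc_score = model(samples, match_head="itc")
        print(f"ITC similarity: {itc_score.item():.4f}")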