# pix2pix-zero-demo-CPU/lavis/models/blip_models/blip_image_text_matching.py
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.blip_models.blip import BlipBase
from torch import nn
from lavis.models.med import XBertEncoder
from lavis.models.vit import VisionTransformerEncoder
@registry.register_model("blip_image_text_matching")
class BlipITM(BlipBase):
"""
BLIP Image-Text Matching (ITM) model.
Supported model types:
        - base: BLIP base retrieval weights fine-tuned on the COCO dataset (Karpathy split).
        - large: BLIP large retrieval weights fine-tuned on the COCO dataset (Karpathy split).
Usage:
>>> from lavis.models import load_model
>>> model = load_model("blip_image_text_matching", "base")
>>> model = load_model("blip_image_text_matching", "large")
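
        Scoring sketch (here `image` and `caption` are placeholders for a
        preprocessed image tensor and a list of caption strings):
        >>> samples = {"image": image, "text_input": caption}
        >>> itm_logits = model(samples, match_head="itm")  # (B, 2) match logits
        >>> itc_sim = model(samples, match_head="itc")     # image-text cosine similarities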
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"base": "configs/models/blip_itm_base.yaml",
"large": "configs/models/blip_itm_large.yaml",
}
def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35):
super().__init__()
self.tokenizer = self.init_tokenizer()
self.text_encoder = text_encoder
self.visual_encoder = image_encoder
self.max_txt_len = max_txt_len
        # projection layers for ITC (image-text contrastive) similarity
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # binary (match / no-match) classification head for ITM
        self.itm_head = nn.Linear(text_width, 2)
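
    # forward() supports two scoring heads:
    #   - "itm": the text encoder cross-attends over image patch embeddings and the
    #     fused [ENC] representation is mapped to 2-way match / no-match logits.
    #   - "itc": text is encoded unimodally and the cosine-similarity matrix between
    #     the projected image and text [CLS] features is returned.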
def forward(self, samples, match_head="itm"):
image = samples["image"]
caption = samples["text_input"]
image_embeds = self.visual_encoder.forward_features(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
text = self.tokenizer(
caption,
padding="longest",
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(image.device)
if match_head == "itm":
            encoder_input_ids = text.input_ids.clone()
            # replace [CLS] with the encoder-specific [ENC] token expected by the ITM head
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
output = self.text_encoder(
encoder_input_ids,
attention_mask=text.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
return itm_output
elif match_head == "itc":
text_output = self.text_encoder(
text.input_ids,
attention_mask=text.attention_mask,
return_dict=True,
mode="text",
)
image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
text_feat = F.normalize(
self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
)
sim = image_feat @ text_feat.t()
return sim
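
    # itm_rank() scores pre-computed image embeddings against already-tokenized
    # caption ids (whose first three tokens are stripped before scoring) and
    # returns either the softmaxed ITM match probability or the ITC cosine
    # similarity, depending on match_head.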
    def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head="itm"):
        encoder_input_ids = encoder_input_ids.clone()
        # the caller's input ids carry a 3-token prefix, which is stripped before scoring
        encoder_input_ids = encoder_input_ids[:, 3:]
        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()

        if match_head == "itm":
            # replace the leading token with [ENC] for the matching head
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            # return the probability of the "match" class
            itm_output = F.softmax(itm_output, dim=1)[:, 1]
            return itm_output
        elif match_head == "itc":
            # restore [CLS] for the unimodal (contrastive) text encoder pass
            encoder_input_ids[:, 0] = self.tokenizer.cls_token_id
            text_output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text_attention_mask,
                return_dict=True,
                mode="text",
            )
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )
            sim = image_feat @ text_feat.t()
            return sim
@classmethod
def from_config(cls, cfg=None):
image_encoder = VisionTransformerEncoder.from_config(cfg)
text_encoder = XBertEncoder.from_config(cfg)
embed_dim = cfg.get("embed_dim", 256)
max_txt_len = cfg.get("max_txt_len", 35)
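        # The corresponding yaml fields (see configs/models/blip_itm_base.yaml)
        # would look roughly like this; the values shown are the defaults above:
        #   embed_dim: 256
        #   max_txt_len: 35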
model = cls(
image_encoder=image_encoder,
text_encoder=text_encoder,
embed_dim=embed_dim,
max_txt_len=max_txt_len,
)
model.load_checkpoint_from_config(cfg)
return model

def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6):
    """Compute per-token GradCAM maps over image patches from one cross-attention layer of the ITM branch."""
    # ask the chosen cross-attention layer to cache its attention maps and gradients
    model.text_encoder.base_model.base_model.encoder.layer[
        block_num
    ].crossattention.self.save_attention = True

    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
with torch.no_grad():
mask = tokenized_text.attention_mask.view(
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
) # (bsz,1,token_len, 1,1)
        # real text tokens only: subtract the leading [ENC]/[CLS] and trailing [SEP]
        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2
        token_length = token_length.cpu()
# grads and cams [bsz, num_head, seq_len, image_patch]
grads = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attn_gradients()
cams = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attention_map()
        # drop the image [CLS] position and reshape the remaining scores to the ViT
        # patch grid; a 24x24 grid (576 patches) is assumed, e.g. 384x384 input with 16x16 patches
        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
grads = (
grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
* mask
)
gradcams = cams * grads
gradcam_list = []
for ind in range(visual_input.size(0)):
token_length_ = token_length[ind]
gradcam = gradcams[ind].mean(0).cpu().detach()
# [enc token gradcam, average gradcam across token, gradcam for individual token]
gradcam = torch.cat(
(
gradcam[0:1, :],
gradcam[1 : token_length_ + 1, :].sum(dim=0, keepdim=True)
/ token_length_,
gradcam[1:, :],
)
)
gradcam_list.append(gradcam)
return gradcam_list, output
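

if __name__ == "__main__":
    # Minimal usage sketch. The 384x384 input size is an assumption derived from the
    # 24x24 patch grid hard-coded in compute_gradcam; real inputs should come from the
    # BLIP image processor rather than torch.randn.
    from lavis.models import load_model

    model = load_model("blip_image_text_matching", "base")
    model.eval()

    image = torch.randn(1, 3, 384, 384)  # dummy image batch
    caption = ["a dog running on the beach"]
    samples = {"image": image, "text_input": caption}

    itm_logits = model(samples, match_head="itm")    # (1, 2) no-match / match logits
    itm_score = F.softmax(itm_logits, dim=1)[:, 1]   # probability that the pair matches
    itc_sim = model(samples, match_head="itc")       # (1, 1) cosine-similarity matrix
    print(float(itm_score), float(itc_sim))

    # GradCAM for the same pair, tokenizing with the model's own tokenizer.
    tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
    gradcams, _ = compute_gradcam(model, image, caption, tokenized)
    print(gradcams[0].shape)  # per-token GradCAM maps over the 24x24 patch grid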