"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import torch
import torch.nn.functional as F
from lavis.common.dist_utils import download_cached_file
from lavis.common.registry import registry
from lavis.common.utils import get_abs_path, is_url
from lavis.models.base_model import MomentumDistilationMixin
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import BlipIntermediateOutput, BlipOutput
from lavis.models.blip_models.nlvr_encoder import BertModel
from lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed
from torch import nn
from transformers import BertConfig


@registry.register_model("blip_nlvr")
class BlipNLVR(BlipBase, MomentumDistilationMixin):
"""
Class for BLIP NLVR model.
Supported model types:
- base: model with pre-trained BLIP weights, used as initialization for fine-tuning.
- nlvr: finetuned model on NLVR2 dataset.
Usage:
>>> from lavis.models import load_model
>>> model = load_model("blip_nlvr", "nlvr")
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"nlvr": "configs/models/blip_nlvr.yaml",
    }

    def __init__(self, image_encoder, text_encoder, num_classes):
super().__init__()
self.tokenizer = self.init_tokenizer()
self.visual_encoder = image_encoder
self.text_encoder = text_encoder
hidden_size = text_encoder.config.hidden_size
self.cls_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, num_classes),
        )

    def forward(self, samples, is_train=True):
"""
Forward function for training and evaluation.
Args:
samples (dict): a dict of input samples, which contains the following keys:
- image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.
- image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.
- text_input (list): list of strings, each string is a natural language sentence.
- label (torch.LongTensor): ground truth label with shape (batch_size,).
            is_train (bool): whether the model is in training mode.
                If True, the model returns the loss;
                if False, it returns a dict of predictions and targets.
Examples:
>>> import torch
>>> from lavis.models import load_model
>>> model = load_model("blip_nlvr", "nlvr")
>>> samples = {
... "image0": torch.randn(2, 3, 384, 384),
... "image1": torch.randn(2, 3, 384, 384),
... "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"],
... "label": torch.tensor([0, 1]),
... }
>>> output = model(samples)
>>> output.keys()
odict_keys(['intermediate_output', 'loss'])
"""
text = samples["text_input"]
text = self.tokenizer(text, padding="longest", return_tensors="pt").to(
self.device
)
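        # Replace the [CLS] token with the [ENC] token so the text model runs as an
        # image-grounded (multimodal) encoder.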
text.input_ids[:, 0] = self.tokenizer.enc_token_id
targets = samples["label"]
image0 = samples["image0"]
image1 = samples["image1"]
images = torch.cat([image0, image1], dim=0)
image_embeds = self.visual_encoder.forward_features(images)
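        # All visual tokens are valid, so the image attention mask is all ones.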
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
self.device
)
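        # Split the features back into the two images of each pair
        # (first half of the batch is image0, second half is image1).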
image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))
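        # The NLVR text encoder cross-attends to both images; each layer carries two
        # cross-attention branches, one per image.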
encoder_output = self.text_encoder(
text.input_ids,
attention_mask=text.attention_mask,
encoder_hidden_states=[image0_embeds, image1_embeds],
encoder_attention_mask=[
image_atts[: image0_embeds.size(0)],
image_atts[image0_embeds.size(0) :],
],
return_dict=True,
)
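        # Classify from the multimodal representation of the first ([ENC]) token.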
prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
if is_train:
loss = F.cross_entropy(prediction, targets)
# return {"loss": loss}
return BlipOutput(
loss=loss,
intermediate_output=BlipIntermediateOutput(
image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),
encoder_output=encoder_output,
),
)
else:
return {"predictions": prediction, "targets": targets}
def predict(self, samples):
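        """Run the model in evaluation mode and return its predictions and targets."""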
output = self.forward(samples, is_train=False)
        return output

    @classmethod
def from_config(cls, cfg=None):
image_encoder = VisionTransformerEncoder.from_config(cfg)
# text encoder + multimodal encoder
bert_config = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"]))
text_encoder = BertModel(config=bert_config, add_pooling_layer=False)
num_classes = cfg.get("num_classes", 3)
assert num_classes > 1, "Invalid number of classes provided, found {}".format(
num_classes
)
model = cls(
image_encoder=image_encoder,
text_encoder=text_encoder,
num_classes=num_classes,
)
model.load_checkpoint_from_config(cfg)
        return model

    def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
state_dict["visual_encoder.pos_embed"], self.visual_encoder
)
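        # The pretrained checkpoint has a single cross-attention per layer, while the NLVR
        # encoder uses two per-image branches; copy the weights into both (self0/self1, dense0/dense1).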
for key in list(state_dict.keys()):
if "crossattention.self." in key:
new_key0 = key.replace("self", "self0")
new_key1 = key.replace("self", "self1")
state_dict[new_key0] = state_dict[key]
state_dict[new_key1] = state_dict[key]
elif "crossattention.output.dense." in key:
new_key0 = key.replace("dense", "dense0")
new_key1 = key.replace("dense", "dense1")
state_dict[new_key0] = state_dict[key]
state_dict[new_key1] = state_dict[key]
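        # strict=False: keys introduced for NLVR (e.g. the classification head) may be
        # absent from a pretrained BLIP checkpoint.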
msg = self.load_state_dict(state_dict, strict=False)
print("load checkpoint from %s" % url_or_filename)
print(f"missing keys {msg.missing_keys}")
return msg