""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import warnings

import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import BlipOutputFeatures
from lavis.models.med import XBertEncoder
from lavis.models.vit import VisionTransformerEncoder
from torch import nn


@registry.register_model("blip_feature_extractor")
class BlipFeatureExtractor(BlipBase):
    """
    Class for BLIP feature extractor.

    Supported model types:
        - base: BLIP base model with pre-trained weights from CapFilt (bootstrapped with the BLIP large model).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_feature_extractor", "base")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/blip_feature_extractor_base.yaml",
        # "large": "configs/models/blip_feature_extractor_large.yaml",
    }

    def __init__(self, image_encoder, text_encoder, embed_dim, max_txt_len=40):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        # creating projection layers for ITC (image-text contrastive) features
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.max_txt_len = max_txt_len

        # learnable temperature used to scale image-text contrastive similarities
        self.temp = nn.Parameter(0.07 * torch.ones([]))

    def extract_features(self, samples, mode="multimodal"):
        """
        Extract features for multimodal or unimodal samples.

        Args:
            samples (dict): A dictionary of samples, containing the following keys:
                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.
                  Raw images should be preprocessed before being passed to the feature extractor.
                - text_input (list): A list of strings containing the text, length B.
            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
                If "multimodal", return image features and multimodal features;
                if "text", return text features;
                if "image", return image features.
                Default: "multimodal".

        Returns:
            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
                See lavis/models/blip_models/blip_outputs.py for more details.

        Examples:
        ```python
            >>> from PIL import Image
            >>> from lavis.models import load_model_and_preprocess
            >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
            >>> caption = "a large fountain spewing water into the air"
            >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_feature_extractor", model_type="base", is_eval=True)
            >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
            >>> text_input = txt_processors["eval"](caption)
            >>> sample = {"image": image, "text_input": [text_input]}

            >>> features_multimodal = model.extract_features(sample)
            >>> features_multimodal.keys()
            odict_keys(['image_embeds', 'multimodal_embeds'])
            >>> features_multimodal.image_embeds.shape
            torch.Size([1, 197, 768])
            >>> features_multimodal.multimodal_embeds.shape
            torch.Size([1, 12, 768])

            >>> features_text = model.extract_features(sample, mode="text")
            >>> features_text.keys()
            odict_keys(['text_embeds', 'text_embeds_proj'])
            >>> features_text.text_embeds.shape
            torch.Size([1, 12, 768])
            >>> features_text.text_embeds_proj.shape
            torch.Size([1, 12, 256])

            >>> features_image = model.extract_features(sample, mode="image")
            >>> features_image.keys()
            odict_keys(['image_embeds', 'image_embeds_proj'])
            >>> features_image.image_embeds.shape
            torch.Size([1, 197, 768])
            >>> features_image.image_embeds_proj.shape
            torch.Size([1, 197, 256])
        ```
        """
        image = samples.get("image")
        caption = samples.get("text_input")

        # assert mode is one of "image", "text", "multimodal"
        assert mode in [
            "image",
            "text",
            "multimodal",
        ], "mode must be one of 'image', 'text', 'multimodal'"

        # initialize output
        image_embeds, text_embeds, multimodal_embeds = None, None, None
        image_features, text_features = None, None
if mode == "image": | |
assert ( | |
image is not None | |
), "Image is not provided for mode 'image' or 'multimodal'" | |
# return image features | |
image_embeds = self.visual_encoder.forward_features(image) | |
image_features = self.vision_proj(image_embeds) | |
image_features = F.normalize(image_features, dim=-1) | |
elif mode == "text": | |
assert ( | |
caption is not None | |
), "text input is None for mode 'text' or 'multimodal'" | |
text = self.tokenizer(caption, return_tensors="pt", padding=True).to( | |
self.device | |
) | |
# return text features | |
text_output = self.text_encoder( | |
text.input_ids, | |
attention_mask=text.attention_mask, | |
return_dict=True, | |
mode="text", | |
) | |
text_embeds = text_output.last_hidden_state | |
text_features = self.text_proj(text_embeds) | |
text_features = F.normalize(text_features, dim=-1) | |
elif mode == "multimodal": | |
# return multimodel features | |
image_embeds = self.visual_encoder.forward_features(image) | |
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( | |
self.device | |
) | |
text = self.tokenizer(caption, return_tensors="pt", padding=True).to( | |
self.device | |
) | |
text.input_ids[:, 0] = self.tokenizer.enc_token_id | |
output = self.text_encoder( | |
text.input_ids, | |
attention_mask=text.attention_mask, | |
encoder_hidden_states=image_embeds, | |
encoder_attention_mask=image_atts, | |
return_dict=True, | |
) | |
multimodal_embeds = output.last_hidden_state | |

        return BlipOutputFeatures(
            image_embeds=image_embeds,
            image_embeds_proj=image_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
            multimodal_embeds=multimodal_embeds,
        )

    @classmethod
    def from_config(cls, cfg=None):
        # set from_pretrained=True to load weights for 'bert-base-uncased'
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        text_encoder = XBertEncoder.from_config(cfg)

        embed_dim = cfg.get("embed_dim", 256)
        max_txt_len = cfg.get("max_txt_len", 30)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

        # load pre-trained weights
        pretrain_path = cfg.get("pretrained", None)
        if pretrain_path is not None:
            msg = model.load_from_pretrained(url_or_filename=pretrain_path)
        else:
            warnings.warn("No pretrained weights are loaded.")

        return model
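

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming LAVIS is installed and the registered
# "blip_feature_extractor" model above is available through the factory.
# The image path is a placeholder assumption; point it at any local RGB image.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    from lavis.models import load_model_and_preprocess

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model together with its matching image and text processors.
    model, vis_processors, txt_processors = load_model_and_preprocess(
        name="blip_feature_extractor", model_type="base", is_eval=True, device=device
    )

    raw_image = Image.open("docs/data/merlion.png").convert("RGB")  # placeholder path
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    text_input = txt_processors["eval"]("a large fountain spewing water into the air")
    sample = {"image": image, "text_input": [text_input]}

    # Unimodal and multimodal features, mirroring the docstring example above.
    image_feats = model.extract_features(sample, mode="image")
    text_feats = model.extract_features(sample, mode="text")
    multimodal_feats = model.extract_features(sample, mode="multimodal")

    print(image_feats.image_embeds_proj.shape)       # e.g. torch.Size([1, 197, 256])
    print(text_feats.text_embeds_proj.shape)         # e.g. torch.Size([1, seq_len, 256])
    print(multimodal_feats.multimodal_embeds.shape)  # e.g. torch.Size([1, seq_len, 768])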