| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
import warnings

import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.common.utils import get_abs_path
from lavis.models.albef_models import AlbefBase
from lavis.models.albef_models.albef_outputs import AlbefOutputFeatures
from lavis.models.med import BertForMaskedLM
from lavis.models.vit import VisionTransformerEncoder
from torch import nn
from transformers import BertConfig


@registry.register_model("albef_feature_extractor")
class AlbefFeatureExtractor(AlbefBase):
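    """Feature extractor built on the ALBEF architecture: a ViT image encoder,
    a BERT text encoder, and fusion layers that cross-attend from text to
    image tokens. Supports unimodal ("image", "text") and "multimodal" feature
    extraction; see extract_features() below.
    """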

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/albef_feature_extractor.yaml",
    }

    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=30):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.embed_dim = embed_dim
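        # Linear heads that project vision/text hidden states into a shared
        # embedding space of size embed_dim, used for image-text alignment.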
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.max_txt_len = max_txt_len
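        # Learnable temperature for scaling image-text similarities
        # (initialized to 0.07); not used inside extract_features itself.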
        self.temp = nn.Parameter(0.07 * torch.ones([]))

    def extract_features(self, samples, mode="multimodal"):
| """ | |
| Extract features for multimodal or unimodal samples. | |
| Args: | |
| samples (dict): A dictionary of samples, containing the following keys: | |
| - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image. | |
| Raw images should be preprocessed before being passed to feature extractor. | |
| - text_input (list): A list of strings containing the text, length B. | |
| mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image". | |
| If "multimodal", return image features and multimodal features; | |
| if "text", return text features; | |
| if "image", return image features. | |
| Default: "multimodal". | |
| Returns: | |
| An AlbefOutputFeatures object, see lavis/models/albef_models/albef_outputs.py for details. | |
| Examples: | |
| ```python | |
| >>> from PIL import Image | |
| >>> from lavis.models import load_model_and_preprocess | |
| >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") | |
| >>> caption = "a large fountain spewing water into the air" | |
| >>> model, vis_processors, txt_processors = load_model_and_preprocess("albef_feature_extractor", is_eval=True) | |
| >>> image = vis_processors["eval"](raw_image).unsqueeze(0) | |
| >>> text_input = txt_processors["eval"](caption) | |
| >>> sample = {"image": image, "text_input": [text_input]} | |
| >>> features_multimodal = model.extract_features(sample) | |
| >>> features_multimodal.keys() | |
| odict_keys(['image_embeds', 'multimodal_embeds']) | |
| >>> features_multimodal.image_embeds.shape | |
| torch.Size([1, 197, 768]) | |
| >>> features_multimodal.multimodal_embeds.shape | |
| torch.Size([1, 12, 768]) | |
| >>> features_text = model.extract_features(sample, mode="text") | |
| >>> features_text.keys() | |
| odict_keys(['text_embeds', 'text_features']) | |
| >>> features_text.text_embeds.shape | |
| torch.Size([1, 12, 768]) | |
| >>> features_text.text_features.shape | |
| torch.Size([1, 12, 256]) | |
| >>> features_image = model.extract_features(sample, mode="image") | |
| >>> features_image.keys() | |
| odict_keys(['image_embeds', 'image_features']) | |
| >>> features_image.image_embeds.shape | |
| torch.Size([1, 197, 768]) | |
| >>> features_image.image_features.shape | |
| torch.Size([1, 197, 256]) | |
| ``` | |
| """ | |
        image = samples.get("image")
        caption = samples.get("text_input")

        if isinstance(mode, str):
            mode = [mode]

        for m in mode:
            assert m in [
                "multimodal",
                "image",
                "text",
            ], "mode must be one of [multimodal, image, text], but got {}".format(m)

        # initialize outputs
        image_embeds, text_embeds, multimodal_embeds = None, None, None
        image_features, text_features = None, None
| if "image" in mode or "multimodal" in mode: | |
| assert ( | |
| image is not None | |
| ), "image must be provided if mode is 'image' or 'multimodal'" | |
            image_embeds = self.visual_encoder.forward_features(image)
            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)
| if "text" in mode or "multimodal" in mode: | |
| assert ( | |
| caption is not None | |
| ), "text must be provided if mode is 'text' or 'multimodal'" | |
            text = self.tokenizer(
                caption,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            text_output = self.text_encoder.bert(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
                mode="text",
            )
            text_embeds = text_output.last_hidden_state
            text_features = F.normalize(self.text_proj(text_embeds), dim=-1)
| if "multimodal" in mode: | |
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                self.device
            )

            # forward the positive image-text pair through the fusion layers
            output = self.text_encoder.bert(
                encoder_embeds=text_embeds,
                attention_mask=text.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
                mode="fusion",
            )
            multimodal_embeds = output.last_hidden_state

        return AlbefOutputFeatures(
            image_embeds=image_embeds,
            image_embeds_proj=image_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
            multimodal_embeds=multimodal_embeds,
        )

    @classmethod
    def from_config(cls, cfg=None):
        image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True)

        config_text_encoder = BertConfig.from_json_file(
            get_abs_path(cfg["med_config_path"])
        )
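        # The first `fusion_layer` BERT layers run as a unimodal text encoder;
        # the remaining layers cross-attend to image features (see lavis/models/med.py).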
        config_text_encoder.fusion_layer = 6
        text_encoder = BertForMaskedLM.from_pretrained(
            "bert-base-uncased", config=config_text_encoder
        )

        embed_dim = cfg.get("embed_dim", 256)
        max_txt_len = cfg.get("max_txt_len", 30)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

        # load pre-trained weights
        pretrain_path = cfg.get("pretrained", None)
        if pretrain_path is not None:
            msg = model.load_from_pretrained(
                url_or_filename=pretrain_path, rename_text_keys=False
            )
        else:
            warnings.warn("No pretrained weights are loaded.")

        return model
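

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It mirrors the
# docstring example above, assuming LAVIS is installed, the model is
# registered as "albef_feature_extractor" with model_type "base", and that
# the script runs from the LAVIS repo root so the example image path resolves
# (the path is illustrative only).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    from lavis.models import load_model_and_preprocess

    # Load the model together with its image/text preprocessors (CPU, eval mode).
    model, vis_processors, txt_processors = load_model_and_preprocess(
        "albef_feature_extractor", model_type="base", is_eval=True
    )

    # Preprocess one image-caption pair into a batch of size 1.
    raw_image = Image.open("docs/data/merlion.png").convert("RGB")  # illustrative path
    caption = "a large fountain spewing water into the air"
    sample = {
        "image": vis_processors["eval"](raw_image).unsqueeze(0),
        "text_input": [txt_processors["eval"](caption)],
    }

    # Unimodal and multimodal features, as described in extract_features().
    image_feats = model.extract_features(sample, mode="image")
    text_feats = model.extract_features(sample, mode="text")
    multimodal_feats = model.extract_features(sample, mode="multimodal")

    print(image_feats.image_embeds.shape)            # e.g. torch.Size([1, 197, 768])
    print(text_feats.text_embeds.shape)              # e.g. torch.Size([1, num_tokens, 768])
    print(multimodal_feats.multimodal_embeds.shape)  # e.g. torch.Size([1, num_tokens, 768])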