""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import torch

from lavis.common.registry import registry
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import (
    BlipOutput,
    BlipIntermediateOutput,
)
from lavis.models.med import XBertLMHeadDecoder
from lavis.models.vit import VisionTransformerEncoder


@registry.register_model("blip_caption")
class BlipCaption(BlipBase):
    """
    BLIP captioning model.

    Supported model types:
        - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karpathy split).
        - large_coco: fine-tuned BLIP large model on COCO caption dataset (Karpathy split).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_caption", "base_coco")
        >>> model = load_model("blip_caption", "large_coco")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base_coco": "configs/models/blip_caption_base_coco.yaml",
        "large_coco": "configs/models/blip_caption_large_coco.yaml",
    }
    def __init__(self, image_encoder, text_decoder, prompt=None, max_txt_len=40):
        super().__init__()
        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_decoder = text_decoder

        self.prompt = prompt
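        # The tokenized prompt includes a trailing [SEP]; subtract one so that
        # prompt_length covers only the leading token plus the prompt tokens,
        # which are later masked out of the captioning loss.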
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        self.max_txt_len = max_txt_len

    def forward_encoder(self, samples):
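        # Encode the raw image batch into a sequence of patch embeddings
        # produced by the vision transformer.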
        image_embeds = self.visual_encoder.forward_features(samples["image"])
        return image_embeds

    def forward_decoder(self, samples, image_embeds):
        # prepare inputs for forwarding decoder
        raw_text = samples["text_input"]
        text = self.tokenizer(
            raw_text,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
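        # Replace the leading [CLS] token with the decoder start (BOS) token
        # so decoding is conditioned from the proper start symbol.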
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        # prepare targets for forwarding decoder
        decoder_targets = text.input_ids.masked_fill(
            text.input_ids == self.tokenizer.pad_token_id, -100
        )
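        # Also exclude the prompt tokens from the language-modeling loss.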
        decoder_targets[:, : self.prompt_length] = -100

        # forward decoder
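        # Every visual token is attended to, so the encoder attention mask is all ones.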
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
        decoder_output = self.text_decoder(
            input_ids=text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            labels=decoder_targets,
            return_dict=True,
        )

        return decoder_output, decoder_targets

    def forward(self, samples):
        r"""
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size.

        Returns:
            output (BlipOutput): A BlipOutput object containing the following
                attributes:
                - loss (torch.Tensor): A scalar tensor containing the total loss. For BlipCaption, this is the same as the LM loss.
                - loss_lm (torch.Tensor): A scalar tensor containing the LM loss.
                - intermediate_outputs (BlipIntermediateOutput): A BlipIntermediateOutput object containing intermediate outputs.
                  See :class:`lavis.models.blip_models.blip_outputs.BlipOutput` for more details.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> text_input = ["a large statue of a person spraying water from a fountain"]
        >>> samples = {"image": image, "text_input": text_input}
        >>> output = model(samples)
        >>> output.keys()
        odict_keys(['intermediate_output', 'loss', 'loss_lm'])
        >>> output.intermediate_output.image_embeds.shape
        torch.Size([1, 577, 768])
        >>> output.intermediate_output.decoder_labels.shape
        torch.Size([1, 13])
        ```"""
        image_embeds = self.forward_encoder(samples)
        decoder_output, decoder_targets = self.forward_decoder(samples, image_embeds)
        # wrap the decoder outputs; for captioning the total loss equals the LM loss
        return BlipOutput(
            loss=decoder_output.loss,
            loss_lm=decoder_output.loss,
            intermediate_output=BlipIntermediateOutput(
                image_embeds=image_embeds,
                decoder_output=decoder_output,
                decoder_labels=decoder_targets,
            ),
        )

    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
        num_captions=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use beam search.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of captions to be generated for each image.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> samples = {"image": image}
        >>> captions = model.generate(samples)
        >>> captions
        ['a large statue of a person spraying water from a fountain']
        >>> captions = model.generate(samples, use_nucleus_sampling=True, num_captions=3)
        >>> captions  # example output, results may vary due to randomness
        ['singapore showing the view of some building',
         'the singapore harbor in twilight, as the weather is going down',
         'the famous singapore fountain at sunset']
        ```
        """
        # prepare inputs for decoder generation.
        encoder_out = self.forward_encoder(samples)
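        # Repeat each image's embeddings so that every copy is decoded into a
        # separate caption when num_captions > 1.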
        image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)

        prompt = [self.prompt] * image_embeds.size(0)
        prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device)
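        # Start the prompt with the decoder's BOS token and drop the trailing
        # [SEP] added by the tokenizer, so generation continues from the prompt.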
        prompt.input_ids[:, 0] = self.tokenizer.bos_token_id
        prompt.input_ids = prompt.input_ids[:, :-1]

        # get decoded text
        decoder_out = self.text_decoder.generate_from_encoder(
            tokenized_prompt=prompt,
            visual_embeds=image_embeds,
            sep_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_nucleus_sampling=use_nucleus_sampling,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
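        # Strip the prompt prefix from each decoded string to keep only the
        # generated caption text.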
        captions = [output[len(self.prompt) :] for output in outputs]
        return captions

    @classmethod
    def from_config(cls, cfg):
        # vision encoder
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal decoder
        text_decoder = XBertLMHeadDecoder.from_config(cfg)

        prompt = cfg.get("prompt", None)
        max_txt_len = cfg.get("max_txt_len", 40)

        model = cls(image_encoder, text_decoder, prompt=prompt, max_txt_len=max_txt_len)
        model.load_checkpoint_from_config(cfg)

        return model