jchwenger's picture
Upload 351 files (#2)
d9272c6 verified
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
from lavis.common.registry import registry
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import (
BlipOutput,
BlipIntermediateOutput,
)
from lavis.models.med import XBertLMHeadDecoder
from lavis.models.vit import VisionTransformerEncoder
@registry.register_model("blip_caption")
class BlipCaption(BlipBase):
    """
    BLIP captioning model.

    Supported model types:
        - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karpathy split).
        - large_coco: fine-tuned BLIP large model on COCO caption dataset (Karpathy split).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_caption", "base_coco")
        >>> model = load_model("blip_caption", "large_coco")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base_coco": "configs/models/blip_caption_base_coco.yaml",
        "large_coco": "configs/models/blip_caption_large_coco.yaml",
    }

    def __init__(self, image_encoder, text_decoder, prompt=None, max_txt_len=40):
        """
        Args:
            image_encoder: Vision backbone; must expose ``forward_features(image)``.
            text_decoder: Text decoder, called directly (with ``labels``) for
                training and via ``generate_from_encoder`` for inference.
            prompt (str, optional): Text prefix used to seed generation; it is
                stripped from generated captions. During training the first
                ``prompt_length`` tokens of every caption are masked out of the
                loss, which assumes ``text_input`` already begins with this
                prompt. ``None`` (the default, and what ``from_config`` passes
                when the config has no ``prompt`` key) means "no prompt".
            max_txt_len (int): Truncation length for tokenized training captions.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_decoder = text_decoder

        # The tokenizer cannot encode ``None``, and ``generate`` slices decoded
        # text by ``len(self.prompt)``. Normalise a missing prompt to the empty
        # string instead of crashing.
        self.prompt = "" if prompt is None else prompt
        # Number of leading tokens (BOS slot + prompt words) excluded from the LM
        # loss. The "- 1" drops the trailing special token the tokenizer appends
        # (presumably [SEP]). That token does not appear at that position in a
        # training caption, which continues past the prompt.
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        self.max_txt_len = max_txt_len

    def forward_encoder(self, samples):
        """Encode ``samples["image"]`` with the vision backbone.

        Returns:
            The output of ``self.visual_encoder.forward_features``. This is used
            as the decoder's cross-attention memory; presumably shaped
            (batch, num_patches + 1, dim), so confirm against the encoder.
        """
        image_embeds = self.visual_encoder.forward_features(samples["image"])
        return image_embeds

    def forward_decoder(self, samples, image_embeds):
        """Run the text decoder with teacher forcing on ``samples["text_input"]``.

        Args:
            samples (dict): Must contain ``"text_input"`` (list of str).
            image_embeds (torch.Tensor): Output of :meth:`forward_encoder`.

        Returns:
            tuple: ``(decoder_output, decoder_targets)``. ``decoder_output`` is the
            decoder's ``return_dict=True`` output (its ``.loss`` is the LM loss).
            ``decoder_targets`` holds the label ids, with padding and prompt
            positions set to -100.
        """
        # prepare inputs for forwarding decoder
        raw_text = samples["text_input"]
        text = self.tokenizer(
            raw_text,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        # Replace the tokenizer's leading special token with the decoder BOS token.
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        # prepare targets for forwarding decoder.
        # -100 is the default ``ignore_index`` of torch.nn.CrossEntropyLoss; this
        # assumes the decoder's LM loss uses it, so padding and prompt tokens
        # contribute no loss.
        decoder_targets = text.input_ids.masked_fill(
            text.input_ids == self.tokenizer.pad_token_id, -100
        )
        decoder_targets[:, : self.prompt_length] = -100

        # Every image position is attendable by the decoder's cross-attention.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
        decoder_output = self.text_decoder(
            input_ids=text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            labels=decoder_targets,
            return_dict=True,
        )

        return decoder_output, decoder_targets

    def forward(self, samples):
        r"""
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size.
        Returns:
            output (BlipOutput): A BlipOutput object containing the following
                attributes:
                - loss (torch.Tensor): A scalar tensor containing the total loss. For BlipCaption, this is the same as the LM loss.
                - loss_lm (torch.Tensor): A scalar tensor containing the LM loss.
                - intermediate_outputs (BlipIntermediateOutput): A BlipIntermediateOutput object containing intermediate outputs.
                  see :class:`lavis.models.blip_models.blip_outputs.BlipOutput` for more details.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> text_input = ["a large statue of a person spraying water from a fountain"]
        >>> samples = {"image": image, "text_input": text_input}
        >>> output = model(samples)
        >>> output.keys()
        odict_keys(['intermediate_output', 'loss', 'loss_lm'])
        >>> output.intermediate_output.image_embeds.shape
        torch.Size([1, 577, 768])
        >>> output.intermediate_output.decoder_labels.shape
        torch.Size([1, 13])
        ```"""

        image_embeds = self.forward_encoder(samples)
        decoder_output, decoder_targets = self.forward_decoder(samples, image_embeds)

        # Captioning has a single objective, so the total loss is the LM loss.
        return BlipOutput(
            loss=decoder_output.loss,
            loss_lm=decoder_output.loss,
            intermediate_output=BlipIntermediateOutput(
                image_embeds=image_embeds,
                decoder_output=decoder_output,
                decoder_labels=decoder_targets,
            ),
        )

    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
        num_captions=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to use nucleus (top-p) sampling.
                If False, the decoder's deterministic search is used; ``num_beams``
                then controls beam search.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of captions to be generated for each image.
        Returns:
            captions (list): A list of strings of length batch_size * num_captions,
            with the prompt prefix removed.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> samples = {"image": image}
        >>> captions = model.generate(samples)
        >>> captions
        ['a large statue of a person spraying water from a fountain']
        >>> captions = model.generate(samples, use_nucleus_sampling=True, num_captions=3)
        >>> captions # example output, results may vary due to randomness
        ['singapore showing the view of some building',
        'the singapore harbor in twilight, as the weather is going down',
        'the famous singapore fountain at sunset']
        ```
        """

        # prepare inputs for decoder generation.
        encoder_out = self.forward_encoder(samples)
        # Repeat each image's embeddings num_captions times (consecutively), so
        # captions for the same image are adjacent in the output.
        image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)

        prompt_texts = [self.prompt] * image_embeds.size(0)
        tokenized_prompt = self.tokenizer(prompt_texts, return_tensors="pt").to(
            self.device
        )
        # Same prefix layout as in training: BOS in place of the first special
        # token, and the trailing special token dropped so generation continues
        # right after the prompt.
        tokenized_prompt.input_ids[:, 0] = self.tokenizer.bos_token_id
        tokenized_prompt.input_ids = tokenized_prompt.input_ids[:, :-1]

        # get decoded text
        decoder_out = self.text_decoder.generate_from_encoder(
            tokenized_prompt=tokenized_prompt,
            visual_embeds=image_embeds,
            sep_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_nucleus_sampling=use_nucleus_sampling,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
        # Strip the prompt by character count. This assumes the decoded text
        # reproduces the prompt verbatim (e.g. an already-lowercased prompt).
        prompt_chars = len(self.prompt)
        captions = [output[prompt_chars:] for output in outputs]

        return captions

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config and load its checkpoint.

        Reads optional ``prompt`` (default ``None``, i.e. no prompt) and
        ``max_txt_len`` (default 40) from ``cfg``. The encoder and decoder
        sub-configs are consumed by their own ``from_config`` methods.
        """
        # vision encoder
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        # text encoder + multimodal decoder
        text_decoder = XBertLMHeadDecoder.from_config(cfg)

        prompt = cfg.get("prompt", None)
        max_txt_len = cfg.get("max_txt_len", 40)

        model = cls(image_encoder, text_decoder, prompt=prompt, max_txt_len=max_txt_len)
        model.load_checkpoint_from_config(cfg)

        return model