"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn as nn
from itertools import chain
from lavis.common.registry import registry
from lavis.models.base_model import BaseModel
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import T5ForConditionalGeneration
from lavis.models.pnp_vqa_models import prepare_qa_input
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
@registry.register_model("pnp_vqa")
class PNPVQA(BaseModel):
"""
PNPVQA model consists of three submodels for zero-shot VQA:
        1. Image-question matching model
2. Image captioning model
3. Question answering model
Supported model types:
- base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base)
- large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large)
- 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b)
Usage:
>>> from lavis.models import load_model
>>> model = load_model("pnp_vqa", "base", is_eval=True)
>>> model = load_model("pnp_vqa", "large", is_eval=True)
>>> model = load_model("pnp_vqa", "3b", is_eval=True)
"""
PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/pnp-vqa/pnp_vqa_base.yaml",
"large": "configs/models/pnp-vqa/pnp_vqa_large.yaml",
"3b": "configs/models/pnp-vqa/pnp_vqa_3b.yaml",
}
def __init__(self, image_question_matching_model, image_captioning_model,
question_answering_model, offload_model=False):
super().__init__()
self.image_question_matching_model = image_question_matching_model
self.image_captioning_model = image_captioning_model
self.question_answering_model = question_answering_model
self.offload_model = offload_model
def forward_itm(self, samples, block_num=7):
"""
Args:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
- text_input (list): A list of strings of length batch_size
            block_num (int): The index of the cross-attention block used for gradcam computation.
Returns:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
- text_input (list): A list of strings of length batch_size
- gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
"""
image = samples['image']
question = [text.strip('?') for text in samples['text_input']]
tokenized_text = self.image_question_matching_model.tokenizer(question, padding='longest', truncation=True,
return_tensors="pt").to(self.image_question_matching_model.device)
with torch.set_grad_enabled(True):
gradcams, _ = compute_gradcam(model=self.image_question_matching_model,
visual_input=image,
text_input=question,
tokenized_text=tokenized_text,
block_num=block_num)
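        # compute_gradcam returns, per sample, a stack of maps: the [CLS]-token map, the map
        # averaged over the question tokens, and then the per-token maps. Keep the token-averaged
        # map (index 1) and flatten it into per-patch relevance scores of shape (batch_size, H*W).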
gradcams = [gradcam_[1] for gradcam_ in gradcams]
samples['gradcams'] = torch.stack(gradcams).reshape(samples['image'].size(0), -1)
return samples
def forward_cap(
self,
samples,
cap_max_length=20,
cap_min_length=0,
top_p=1,
top_k=50,
repetition_penalty=1.0,
num_captions=100,
num_patches=20,
):
"""
Args:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
- text_input (list): A list of strings of length batch_size
- gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
cap_max_length (int): The maximum length of the caption to be generated.
cap_min_length (int): The minimum length of the caption to be generated.
top_p (float): The cumulative probability for nucleus sampling.
            top_k (int): The number of highest-probability tokens kept for top-k sampling.
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
num_captions (int): Number of captions generated for each image.
num_patches (int): Number of patches sampled for each image.
Returns:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
- text_input (list): A list of strings of length batch_size
- gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
- captions (nested list): A nested list of strings of total length batch_size * num_captions
"""
encoder_out = self.image_captioning_model.forward_encoder(samples)
captions = [[] for _ in range(encoder_out.size(0))]
min_num_captions = 0
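        # Keep sampling until every image has accumulated `num_captions` unique captions;
        # duplicates are filtered below, so a single pass may not be enough.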
while min_num_captions < num_captions:
encoder_out_samples = []
for i in range(num_captions):
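                # Sample `num_patches` patch indices per image, weighted by the question-guided
                # gradcam scores. The +1 offset skips the CLS embedding at position 0 of
                # encoder_out; gather then selects those patch embeddings as one "view" of the image.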
patch_id = torch.multinomial(samples['gradcams'].to(self.image_captioning_model.device),
num_patches).reshape(encoder_out.size(0), -1) + 1
patch_id = patch_id.sort(dim=1).values.unsqueeze(-1).expand(-1, -1, encoder_out.size(2))
encoder_out_sample = torch.gather(encoder_out, 1, patch_id)
encoder_out_samples.append(encoder_out_sample)
stacked = torch.stack(encoder_out_samples, dim=1)
image_embeds = torch.flatten(stacked, start_dim=0, end_dim=1) #(bsz*num_seq, num_patch, dim)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.image_captioning_model.device)
model_kwargs = {
"encoder_hidden_states": image_embeds,
"encoder_attention_mask": image_atts,
}
prompt = [self.image_captioning_model.prompt] * image_embeds.size(0)
prompt = self.image_captioning_model.tokenizer(prompt,
return_tensors="pt").to(self.image_captioning_model.device)
prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id
prompt.input_ids = prompt.input_ids[:, :-1]
decoder_out = self.image_captioning_model.text_decoder.generate(
input_ids=prompt.input_ids,
max_length=cap_max_length,
min_length=cap_min_length,
do_sample=True,
top_p=top_p,
top_k=top_k,
num_return_sequences=1,
eos_token_id=self.image_captioning_model.tokenizer.sep_token_id,
pad_token_id=self.image_captioning_model.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
outputs = self.image_captioning_model.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
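            # Each image contributed `num_captions` sampled views in this pass, so decoded caption
            # `counter` belongs to image `counter // num_captions`. A caption is kept only if it is
            # not already contained in one collected for that image.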
for counter, output in enumerate(outputs):
ind = counter//num_captions
if len(captions[ind]) < num_captions:
caption = output[len(self.image_captioning_model.prompt):]
overlap_caption = [1 for caps in captions[ind] if caption in caps]
if len(overlap_caption) == 0:
captions[ind].append(caption)
min_num_captions = min([len(i) for i in captions])
samples['captions'] = captions
return samples
def forward_qa(
self,
samples,
num_beams=1,
max_len=20,
min_len=0,
internal_bsz_fid=1,
num_captions=100,
num_captions_fid=1,
):
"""
Args:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
- text_input (list): A list of strings of length batch_size
- gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
- captions (nested list): A nested list of strings of total length batch_size * num_captions
                - question_captions (nested list): Concatenated question-and-caption strings (filled in by prepare_qa_input within this method).
num_beams (int): Number of beams for beam search. 1 means no beam search.
max_len (int): Maximum length of generated answers.
min_len (int): Minimum length of generated answers.
internal_bsz_fid (int): Internal batch size when using FiD decoding.
num_captions (int): Number of captions generated for each image.
num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.
Returns:
List: A list of strings, each string is an answer.
"""
prepare_qa_input(samples, num_captions=num_captions, num_captions_fid=num_captions_fid)
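        # prepare_qa_input populates samples['question_captions']: the generated captions are
        # grouped into chunks of `num_captions_fid`, each chunk prefixed with its question, giving
        # num_captions / num_captions_fid passages per question (Fusion-in-Decoder style).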
pred_answers = []
question_captions = samples['question_captions']
question_captions_chunk = [question_captions[i:i + internal_bsz_fid]
for i in range(0, len(question_captions), internal_bsz_fid)]
question_captions_chunk = list(chain(*question_captions_chunk))
for question_caption in question_captions_chunk:
question_caption_input = self.question_answering_model.tokenizer(question_caption, padding='longest',
truncation=True, return_tensors="pt").to(self.question_answering_model.device)
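            # Reshape the tokenized passages to (internal_bsz_fid, num_passages, seq_len) so the
            # FiD question-answering model encodes each question+captions passage independently
            # and fuses them in the decoder when generating the answer.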
question_caption_input.input_ids = question_caption_input.input_ids.reshape(
internal_bsz_fid, -1, question_caption_input.input_ids.size(1))
question_caption_input.attention_mask = question_caption_input.attention_mask.reshape(
internal_bsz_fid, -1, question_caption_input.attention_mask.size(1))
outputs = self.question_answering_model.generate(input_ids=question_caption_input.input_ids,
attention_mask=question_caption_input.attention_mask,
num_beams=num_beams,
min_length=min_len,
max_length=max_len,
)
for output in outputs:
pred_answer = self.question_answering_model.tokenizer.decode(output, skip_special_tokens=True)
pred_answers.append(pred_answer)
return pred_answers
def predict_answers(
self,
samples,
num_beams=1,
inference_method="generate",
max_len=20,
min_len=0,
internal_bsz_fid=1,
num_captions=50,
num_captions_fid=1,
cap_max_length=20,
cap_min_length=10,
top_k=50,
top_p=1,
repetition_penalty=1,
num_patches=50,
block_num=7,
):
"""
Args:
samples (dict): A dictionary containing the following keys:
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
- text_input (str or [str]): String or a list of strings, each string is a question.
                The number of questions must equal the batch size. A single string is first
                converted to a list of length 1.
num_beams (int): Number of beams for beam search. 1 means no beam search.
inference_method (str): Inference method. Must be "generate". The model will generate answers.
max_len (int): Maximum length of generated answers.
min_len (int): Minimum length of generated answers.
internal_bsz_fid (int): Internal batch size when using FiD decoding.
num_captions (int): Number of captions generated for each image.
num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.
cap_max_length (int): The maximum length of the caption to be generated.
cap_min_length (int): The minimum length of the caption to be generated.
            top_k (int): The number of highest-probability tokens kept for top-k sampling.
top_p (float): The cumulative probability for nucleus sampling.
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
num_patches (int): Number of patches sampled for each image.
            block_num (int): The index of the cross-attention block used for gradcam computation.
        Returns:
            pred_answers (List[str]): A list of predicted answers, one per question.
            captions (nested list): A nested list of strings of total length batch_size * num_captions.
            gradcams (torch.Tensor): A tensor of shape (batch_size, H*W).
"""
assert inference_method in [
"generate",
], "Inference method must be 'generate', got {}.".format(
inference_method
)
if isinstance(samples["text_input"], str):
samples["text_input"] = [samples["text_input"]]
assert len(samples["text_input"]) == samples["image"].size(
0
), "The number of questions must be equal to the batch size."
samples = self.forward_itm(samples, block_num=block_num)
samples = self.forward_cap(samples,
cap_max_length=cap_max_length,
cap_min_length=cap_min_length,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
num_captions=num_captions,
num_patches=num_patches)
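        # Optionally free GPU memory held by the ITM and captioning models before running the
        # (potentially much larger) QA model; from_config enables this for the 3b configuration.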
if self.offload_model:
samples['image'] = samples['image'].to('cpu')
self.image_question_matching_model.to('cpu')
self.image_captioning_model.to('cpu')
torch.cuda.empty_cache()
pred_answers = self.forward_qa(samples,
num_beams=num_beams,
max_len=max_len,
min_len=min_len,
internal_bsz_fid=internal_bsz_fid,
num_captions=num_captions,
num_captions_fid=num_captions_fid)
if self.offload_model:
self.image_question_matching_model.to(self.question_answering_model.device)
self.image_captioning_model.to(self.question_answering_model.device)
return pred_answers, samples['captions'], samples['gradcams']
@classmethod
def from_config(cls, model_config):
itm_config = model_config.image_question_matching_model
cap_config = model_config.image_captioning_model
qa_config = model_config.question_answering_model
itm_cls = registry.get_model_class(itm_config.arch)
cap_cls = registry.get_model_class(cap_config.arch)
qa_cls = registry.get_model_class(qa_config.arch)
image_question_matching_model = itm_cls.from_config(itm_config)
image_captioning_model = cap_cls.from_config(cap_config)
question_answering_model = qa_cls.from_config(qa_config)
model = cls(image_question_matching_model=image_question_matching_model,
image_captioning_model=image_captioning_model,
question_answering_model=question_answering_model,
                    offload_model=(model_config.model_type == '3b'),
)
return model