""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import torch | |
import torch.nn as nn | |
from itertools import chain | |
from lavis.common.registry import registry | |
from lavis.models.base_model import BaseModel | |
from torch.nn import CrossEntropyLoss, MSELoss | |
from transformers import T5ForConditionalGeneration | |
from lavis.models.pnp_vqa_models import prepare_qa_input | |
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam | |
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions | |
@registry.register_model("pnp_vqa")
class PNPVQA(BaseModel):
    """
    PNPVQA model consists of three submodels for zero-shot VQA:
        1. Image-question matching model
        2. Image captioning model
        3. Question answering model

    Supported model types:
        - base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base)
        - large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large)
        - 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b)

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("pnp_vqa", "base", is_eval=True)
        >>> model = load_model("pnp_vqa", "large", is_eval=True)
        >>> model = load_model("pnp_vqa", "3b", is_eval=True)
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/pnp-vqa/pnp_vqa_base.yaml",
        "large": "configs/models/pnp-vqa/pnp_vqa_large.yaml",
        "3b": "configs/models/pnp-vqa/pnp_vqa_3b.yaml",
    }
    def __init__(self, image_question_matching_model, image_captioning_model,
                 question_answering_model, offload_model=False):
        super().__init__()

        self.image_question_matching_model = image_question_matching_model
        self.image_captioning_model = image_captioning_model
        self.question_answering_model = question_answering_model
        self.offload_model = offload_model
    def forward_itm(self, samples, block_num=7):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size
            block_num (int): The index of the cross-attention block used for GradCAM computation.

        Returns:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size
                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
        """
        image = samples['image']
        question = [text.strip('?') for text in samples['text_input']]

        tokenized_text = self.image_question_matching_model.tokenizer(question, padding='longest', truncation=True,
                                                                      return_tensors="pt").to(self.image_question_matching_model.device)
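        # GradCAM needs gradients through the ITM cross-attention, so gradients are
        # enabled explicitly here even if the caller runs inference under torch.no_grad().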
        with torch.set_grad_enabled(True):
            gradcams, _ = compute_gradcam(model=self.image_question_matching_model,
                                          visual_input=image,
                                          text_input=question,
                                          tokenized_text=tokenized_text,
                                          block_num=block_num)
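        # compute_gradcam returns several maps per sample; keep index 1 (the map
        # aggregated over the question tokens) and flatten it to one relevance
        # weight per image patch.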
        gradcams = [gradcam_[1] for gradcam_ in gradcams]
        samples['gradcams'] = torch.stack(gradcams).reshape(samples['image'].size(0), -1)

        return samples
    def forward_cap(
            self,
            samples,
            cap_max_length=20,
            cap_min_length=0,
            top_p=1,
            top_k=50,
            repetition_penalty=1.0,
            num_captions=100,
            num_patches=20,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size
                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
            cap_max_length (int): The maximum length of the caption to be generated.
            cap_min_length (int): The minimum length of the caption to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            top_k (int): The number of the highest-probability tokens kept for top-k sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of captions generated for each image.
            num_patches (int): Number of patches sampled for each image.

        Returns:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size
                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
                - captions (nested list): A nested list of strings of total length batch_size * num_captions
        """
        encoder_out = self.image_captioning_model.forward_encoder(samples)
        captions = [[] for _ in range(encoder_out.size(0))]

        min_num_captions = 0

        while min_num_captions < num_captions:
            encoder_out_samples = []
            for i in range(num_captions):
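                # Sample num_patches patch indices per image from a multinomial over the
                # GradCAM weights. The +1 offset assumes position 0 of the encoder output
                # is the ViT [CLS]/global token, so patch embeddings start at index 1.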
                patch_id = torch.multinomial(samples['gradcams'].to(self.image_captioning_model.device),
                                             num_patches).reshape(encoder_out.size(0), -1) + 1
                patch_id = patch_id.sort(dim=1).values.unsqueeze(-1).expand(-1, -1, encoder_out.size(2))
                encoder_out_sample = torch.gather(encoder_out, 1, patch_id)
                encoder_out_samples.append(encoder_out_sample)

            stacked = torch.stack(encoder_out_samples, dim=1)
            image_embeds = torch.flatten(stacked, start_dim=0, end_dim=1)  # (bsz*num_seq, num_patch, dim)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.image_captioning_model.device)

            model_kwargs = {
                "encoder_hidden_states": image_embeds,
                "encoder_attention_mask": image_atts,
            }

            prompt = [self.image_captioning_model.prompt] * image_embeds.size(0)
            prompt = self.image_captioning_model.tokenizer(prompt,
                                                           return_tensors="pt").to(self.image_captioning_model.device)
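            # Replace the leading token with the decoder's BOS token and drop the final
            # separator token so generation continues the prompt instead of ending at it.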
            prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id
            prompt.input_ids = prompt.input_ids[:, :-1]

            decoder_out = self.image_captioning_model.text_decoder.generate(
                input_ids=prompt.input_ids,
                max_length=cap_max_length,
                min_length=cap_min_length,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=1,
                eos_token_id=self.image_captioning_model.tokenizer.sep_token_id,
                pad_token_id=self.image_captioning_model.tokenizer.pad_token_id,
                repetition_penalty=repetition_penalty,
                **model_kwargs)

            outputs = self.image_captioning_model.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
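            # Outputs are laid out image-by-image (num_captions candidates per image).
            # Strip the prompt prefix and keep a caption only if it is not already
            # contained in a caption kept for that image.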
            for counter, output in enumerate(outputs):
                ind = counter // num_captions
                if len(captions[ind]) < num_captions:
                    caption = output[len(self.image_captioning_model.prompt):]
                    overlap_caption = [1 for caps in captions[ind] if caption in caps]
                    if len(overlap_caption) == 0:
                        captions[ind].append(caption)

            min_num_captions = min([len(i) for i in captions])

        samples['captions'] = captions

        return samples
    def forward_qa(
            self,
            samples,
            num_beams=1,
            max_len=20,
            min_len=0,
            internal_bsz_fid=1,
            num_captions=100,
            num_captions_fid=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - text_input (list): A list of strings of length batch_size
                - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
                - captions (nested list): A nested list of strings of total length batch_size * num_captions
                - question_captions (nested list): A nested list of concatenated strings of questions and captions
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_len (int): Maximum length of generated answers.
            min_len (int): Minimum length of generated answers.
            internal_bsz_fid (int): Internal batch size when using FiD decoding.
            num_captions (int): Number of captions generated for each image.
            num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.

        Returns:
            List: A list of strings, each string is an answer.
        """
        prepare_qa_input(samples, num_captions=num_captions, num_captions_fid=num_captions_fid)

        pred_answers = []
        question_captions = samples['question_captions']
        question_captions_chunk = [question_captions[i:i + internal_bsz_fid]
                                   for i in range(0, len(question_captions), internal_bsz_fid)]
        question_captions_chunk = list(chain(*question_captions_chunk))

        for question_caption in question_captions_chunk:
            question_caption_input = self.question_answering_model.tokenizer(question_caption, padding='longest',
                                                                             truncation=True, return_tensors="pt").to(self.question_answering_model.device)

            question_caption_input.input_ids = question_caption_input.input_ids.reshape(
                internal_bsz_fid, -1, question_caption_input.input_ids.size(1))
            question_caption_input.attention_mask = question_caption_input.attention_mask.reshape(
                internal_bsz_fid, -1, question_caption_input.attention_mask.size(1))
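            # The FiD model expects 3-D inputs of shape (batch, n_contexts, seq_len):
            # each (question, captions) context is encoded separately, then fused in the decoder.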
            outputs = self.question_answering_model.generate(input_ids=question_caption_input.input_ids,
                                                             attention_mask=question_caption_input.attention_mask,
                                                             num_beams=num_beams,
                                                             min_length=min_len,
                                                             max_length=max_len,
                                                             )

            for output in outputs:
                pred_answer = self.question_answering_model.tokenizer.decode(output, skip_special_tokens=True)
                pred_answers.append(pred_answer)

        return pred_answers
    def predict_answers(
            self,
            samples,
            num_beams=1,
            inference_method="generate",
            max_len=20,
            min_len=0,
            internal_bsz_fid=1,
            num_captions=50,
            num_captions_fid=1,
            cap_max_length=20,
            cap_min_length=10,
            top_k=50,
            top_p=1,
            repetition_penalty=1,
            num_patches=50,
            block_num=7,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480.
                - text_input (str or [str]): String or a list of strings, each string is a question.
                  The number of questions must be equal to the batch size. A single string is
                  first converted to a list of length 1.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            inference_method (str): Inference method. Must be "generate". The model will generate answers.
            max_len (int): Maximum length of generated answers.
            min_len (int): Minimum length of generated answers.
            internal_bsz_fid (int): Internal batch size when using FiD decoding.
            num_captions (int): Number of captions generated for each image.
            num_captions_fid (int): Number of captions concatenated with a question during FiD decoding.
            cap_max_length (int): The maximum length of the caption to be generated.
            cap_min_length (int): The minimum length of the caption to be generated.
            top_k (int): The number of the highest-probability tokens kept for top-k sampling.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_patches (int): Number of patches sampled for each image.
            block_num (int): The index of the cross-attention block used for GradCAM computation.

        Returns:
            List: A list of strings, each string is an answer.
            gradcams (torch.Tensor): A tensor of shape (batch_size, H*W)
            captions (nested list): A nested list of strings of total length batch_size * num_captions
        """
        assert inference_method in [
            "generate",
        ], "Inference method must be 'generate', got {}.".format(
            inference_method
        )

        if isinstance(samples["text_input"], str):
            samples["text_input"] = [samples["text_input"]]

        assert len(samples["text_input"]) == samples["image"].size(
            0
        ), "The number of questions must be equal to the batch size."

        samples = self.forward_itm(samples, block_num=block_num)

        samples = self.forward_cap(samples,
                                   cap_max_length=cap_max_length,
                                   cap_min_length=cap_min_length,
                                   top_k=top_k,
                                   top_p=top_p,
                                   repetition_penalty=repetition_penalty,
                                   num_captions=num_captions,
                                   num_patches=num_patches)
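        # Optionally move the ITM and captioning models to CPU before QA to free GPU memory;
        # from_config enables this for the 3b variant, whose t5-3b QA model is the largest.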
        if self.offload_model:
            samples['image'] = samples['image'].to('cpu')
            self.image_question_matching_model.to('cpu')
            self.image_captioning_model.to('cpu')
        torch.cuda.empty_cache()

        pred_answers = self.forward_qa(samples,
                                       num_beams=num_beams,
                                       max_len=max_len,
                                       min_len=min_len,
                                       internal_bsz_fid=internal_bsz_fid,
                                       num_captions=num_captions,
                                       num_captions_fid=num_captions_fid)

        if self.offload_model:
            self.image_question_matching_model.to(self.question_answering_model.device)
            self.image_captioning_model.to(self.question_answering_model.device)

        return pred_answers, samples['captions'], samples['gradcams']
    @classmethod
    def from_config(cls, model_config):
        itm_config = model_config.image_question_matching_model
        cap_config = model_config.image_captioning_model
        qa_config = model_config.question_answering_model
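        # Each sub-model is looked up in the registry by its `arch` name and built from
        # its own sub-config, so the three components stay independently configurable.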
        itm_cls = registry.get_model_class(itm_config.arch)
        cap_cls = registry.get_model_class(cap_config.arch)
        qa_cls = registry.get_model_class(qa_config.arch)

        image_question_matching_model = itm_cls.from_config(itm_config)
        image_captioning_model = cap_cls.from_config(cap_config)
        question_answering_model = qa_cls.from_config(qa_config)

        model = cls(image_question_matching_model=image_question_matching_model,
                    image_captioning_model=image_captioning_model,
                    question_answering_model=question_answering_model,
                    offload_model=(model_config.model_type == '3b'),
                    )

        return model