# explain-LXMERT / generic.py
import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests
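
# Local copies of the Visual Genome object/attribute vocabularies and the VQA
# trainval answer file; the file names mirror their original download URLs.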
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"

class ModelUsage:
    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"
        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")

    def forward(self, item):
        URL, question = item
        self.image_file_path = URL

        # run frcnn to extract region-of-interest features and bounding boxes
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )

        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output

def save_image_vis(image_file_path, bbox_scores):
    _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)

    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])

    # model_lrp.bboxes holds unnormalized [x1, y1, x2, y2] corner boxes from the FRCNN;
    # paint each box region with its relevance score, keeping the maximum where boxes overlap
    for index in range(len(bbox_scores)):
        [x1, y1, x2, y2] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y1):int(y2), int(x1):int(x2)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * bbox_scores[index].item()
        mask[int(y1):int(y2), int(x1):int(x2)] = torch.max(new_score_tensor, curr_score_tensor)

    # normalize the mask to [0, 1] and use it to dim image regions with low relevance
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze_(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    cv2.imwrite('lxmert/lxmert/experiments/paper/new.jpg', img)
    return img

model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)
image_ids = [
    # giraffe
    'COCO_val2014_000000185590',
    # baseball
    'COCO_val2014_000000127510',
    # bath
    'COCO_val2014_000000324266',
    # frisbee
    'COCO_val2014_000000200717'
]
test_questions_for_images = [
    ################## paper samples
    # giraffe
    "is the animal eating?",
    # baseball
    "did he catch the ball?",
    # bath
    "is the tub white ?",
    # frisbee
    "did the man just catch the frisbee?"
    ################## paper samples
]
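
# A minimal usage sketch (an assumption, not part of the original file), showing how the
# pieces above fit together. It assumes the COCO val2014 images are available locally
# under a hypothetical COCO_VAL_PATH, and that GeneratorOurs.generate_ours accepts the
# (image_path, question) pair consumed by ModelUsage.forward and returns
# (text_relevance, image_relevance) maps, as in the Transformer-MM-Explainability demos;
# check lxmert/src/ExplanationGenerator.py for the exact signature.
if __name__ == "__main__":
    COCO_VAL_PATH = './val2014/'  # hypothetical local path to the COCO val2014 images
    for image_id, question in zip(image_ids, test_questions_for_images):
        image_file_path = COCO_VAL_PATH + image_id + '.jpg'
        # relevance maps for the question tokens (R_t_t) and the image boxes (R_t_i)
        R_t_t, R_t_i = lrp.generate_ours(
            (image_file_path, question),
            use_lrp=True,
            normalize_self_attention=True,
            method_name="ours",
        )
        # dim the image by per-box relevance and print the predicted answer
        save_image_vis(image_file_path, R_t_i[0])
        answer = vqa_answers[model_lrp.output.question_answering_score.argmax().item()]
        print(question, '->', answer)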