import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests
# Local copies of the Visual Genome object/attribute vocabularies and the VQA label-to-answer map
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
class ModelUsage:
    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"

        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")
    def forward(self, item):
        URL, question = item
        self.image_file_path = URL

        # run frcnn to extract region features and bounding boxes
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )

        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
def save_image_vis(image_file_path, bbox_scores):
    # Overlay per-box relevance scores on the image: regions covered by high-scoring
    # boxes stay bright, the rest is darkened.
    _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)

    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in range(len(bbox_scores)):
        # boxes are stored as (x1, y1, x2, y2) pixel coordinates
        [x1, y1, x2, y2] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y1):int(y2), int(x1):int(x2)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * bbox_scores[index].item()
        mask[int(y1):int(y2), int(x1):int(x2)] = torch.max(new_score_tensor, curr_score_tensor)
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze_(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    cv2.imwrite('lxmert/lxmert/experiments/paper/new.jpg', img)
    return img
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)
image_ids = [
    # giraffe
    'COCO_val2014_000000185590',
    # baseball
    'COCO_val2014_000000127510',
    # bath
    'COCO_val2014_000000324266',
    # frisbee
    'COCO_val2014_000000200717'
]
test_questions_for_images = [
    ################## paper samples
    # giraffe
    "is the animal eating?",
    # baseball
    "did he catch the ball?",
    # bath
    "is the tub white ?",
    # frisbee
    "did the man just catch the frisbee?"
    ################## paper samples
]
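
# Minimal usage sketch (not part of the original snippet): run the paper samples end to end.
# Assumptions are labeled below -- the image directory and the GeneratorOurs call signature
# (generate_ours((image_path, question), ...) returning text/image relevance lists) follow the
# Transformer-MM-Explainability LXMERT demo and may differ in your checkout.
COCO_IMAGE_DIR = 'lxmert/lxmert/experiments/paper'  # assumed location of the COCO_val2014 jpgs

for image_id, question in zip(image_ids, test_questions_for_images):
    image_path = f"{COCO_IMAGE_DIR}/{image_id}.jpg"

    # forward pass: the predicted answer comes from the question-answering head
    output = model_lrp.forward((image_path, question))
    answer = vqa_answers[output.question_answering_score.argmax(-1).item()]
    print(f"{image_id} | Q: {question} | A: {answer}")

    # assumed generator API: relevance scores for text tokens (R_t_t) and image boxes (R_t_i)
    R_t_t, R_t_i = lrp.generate_ours((image_path, question),
                                     use_lrp=False,
                                     normalize_self_attention=True,
                                     method_name="ours")
    save_image_vis(image_path, R_t_i[0])  # writes the highlighted image to .../new.jpg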