WwYc committed on
Commit
5a95ff9
·
verified ·
1 Parent(s): d8a5b08

Create generic.py

Files changed (1)
  1. generic.py +137 -0
generic.py ADDED
@@ -0,0 +1,137 @@
import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests

# local copies of the object/attribute vocabularies and the VQA answer vocabulary
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"


class ModelUsage:
    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"

        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)

        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")

    def forward(self, item):
        URL, question = item

        self.image_file_path = URL

        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output


def save_image_vis(image_file_path, bbox_scores):
    _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)
    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    # paint each detected box with its relevance score, keeping the maximum where boxes overlap
    for index in range(len(bbox_scores)):
        [x, y, w, h] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y):int(h), int(x):int(w)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * bbox_scores[index].item()
        mask[int(y):int(h), int(x):int(w)] = torch.max(new_score_tensor, mask[int(y):int(h), int(x):int(w)])
    # normalize the mask to [0, 1] and use it to dim the image outside relevant regions
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze_(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    cv2.imwrite('lxmert/lxmert/experiments/paper/new.jpg', img)
    return img


# build the LRP-enabled model wrapper and the explanation generators
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)

image_ids = [
    # giraffe
    'COCO_val2014_000000185590',
    # baseball
    'COCO_val2014_000000127510',
    # bath
    'COCO_val2014_000000324266',
    # frisbee
    'COCO_val2014_000000200717'
]

test_questions_for_images = [
    ################## paper samples
    # giraffe
    "is the animal eating?",
    # baseball
    "did he catch the ball?",
    # bath
    "is the tub white ?",
    # frisbee
    "did the man just catch the frisbee?"
    ################## paper samples
]
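
The file stops after defining the sample image/question pairs and never actually runs the model on them. Below is a minimal usage sketch, not part of the commit: it assumes the COCO val2014 images are available locally as '<image_id>.jpg' (a hypothetical path; adjust to wherever the images actually live) and decodes the prediction from the question_answering_score logits returned by the LXMERT VQA head. The per-box relevance scores produced by lrp or baselines would be passed to save_image_vis in the same way to visualize which regions drove the answer.

# Usage sketch (assumption: images stored locally as './<image_id>.jpg')
for image_id, question in zip(image_ids, test_questions_for_images):
    image_path = image_id + '.jpg'  # hypothetical local path to the COCO val2014 image
    output = model_lrp.forward((image_path, question))
    # question_answering_score holds one logit per answer in the VQA answer vocabulary
    answer_idx = output.question_answering_score.argmax(dim=-1).item()
    print(question, '->', vqa_answers[answer_idx])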