Delete lxmert/perturbation.py

lxmert/perturbation.py  +0 -254  DELETED

@@ -1,254 +0,0 @@
from lxmert.lxmert.src.tasks import vqa_data
from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.lxmert.src.vqa_utils as utils
from lxmert.lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
from lxmert.lxmert.src.param import args

OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"
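For reference, VQA_URL points at LXMERT's trainval label2ans file: a JSON list in which index i holds the answer string for answer id i. A minimal sketch of what utils.get_data is expected to return (urllib/json here are an illustrative assumption, not necessarily how vqa_utils fetches it):

import json
import urllib.request

with urllib.request.urlopen(VQA_URL) as f:
    vqa_answers = json.load(f)  # list mapping answer id -> answer string
print(len(vqa_answers))  # 3129 answer classes for LXMERT's VQA head
print(vqa_answers[:3])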
class ModelPert:
    def __init__(self, COCO_val_path, use_lrp=False):
        self.COCO_VAL_PATH = COCO_val_path
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cuda"

        self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)

        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        self.vqa_dataset = vqa_data.VQADataset(splits="valid")

        # fraction of tokens/boxes removed at each perturbation step
        self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
        self.pert_acc = [0] * len(self.pert_steps)
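A quick illustration of how these fractions translate into kept image regions, assuming the common Visual Genome FRCNN setting of 36 detections (the actual count comes from frcnn_cfg.max_detections at runtime, via self.image_boxes_len below):

steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
num_boxes = 36  # assumed detection count, for illustration only
for step in steps:
    print(step, int((1 - step) * num_boxes))  # 0 -> 36 kept, 0.5 -> 18 kept, 1 -> 0 kept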
    def forward(self, item):
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        self.image_file_path = image_file_path
        self.image_id = item['img_id']
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids.to("cuda"),
            attention_mask=inputs.attention_mask.to("cuda"),
            visual_feats=features.to("cuda"),
            visual_pos=normalized_boxes.to("cuda"),
            token_type_ids=inputs.token_type_ids.to("cuda"),
            return_dict=True,
            output_attentions=False,
        )
        return self.output
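A hedged usage sketch of forward (the item fields 'img_id' and 'sent' follow the VQADataset convention used throughout this file; the COCO path is a placeholder):

pert = ModelPert("/path/to/COCO/val2014/", use_lrp=True)  # placeholder path
item = pert.vqa_dataset.data[0]
out = pert.forward(item)
answer = pert.vqa_answers[out.question_answering_score.argmax()]
print(item['sent'], '->', answer)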
    def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False):
        if is_positive_pert:
            cam_image = cam_image * (-1)
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        for step_idx, step in enumerate(self.pert_steps):
            # find top step boxes (self.image_boxes_len is set by a preceding forward() call)
            curr_num_boxes = int((1 - step) * self.image_boxes_len)
            _, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1)
            top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()

            curr_features = features[:, top_bboxes_indices, :]
            curr_pos = normalized_boxes[:, top_bboxes_indices, :]

            output = self.lxmert_vqa(
                input_ids=inputs.input_ids.to("cuda"),
                attention_mask=inputs.attention_mask.to("cuda"),
                visual_feats=curr_features.to("cuda"),
                visual_pos=curr_pos.to("cuda"),
                token_type_ids=inputs.token_type_ids.to("cuda"),
                return_dict=True,
                output_attentions=False,
            )

            answer = self.vqa_answers[output.question_answering_score.argmax()]
            accuracy = item["label"].get(answer, 0)
            self.pert_acc[step_idx] += accuracy

        return self.pert_acc
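The is_positive_pert sign flip above is what distinguishes the two tests: topk on the negated map keeps the least relevant boxes first, so positive perturbation removes the most relevant regions early. A toy check:

import torch

cam = torch.tensor([0.9, 0.1, 0.5, 0.3])
print(cam.topk(k=2).indices)         # tensor([0, 2]): most relevant kept (negative pert)
print((cam * -1).topk(k=2).indices)  # tensor([1, 3]): least relevant kept (positive pert)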
    def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False):
        if is_positive_pert:
            cam_text = cam_text * (-1)
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        for step_idx, step in enumerate(self.pert_steps):
            # we must keep the [CLS] token in order to have the classification
            # we also keep the [SEP] token
            cam_pure_text = cam_text[1:-1]
            text_len = cam_pure_text.shape[0]
            # find top step tokens, without the [CLS] token and the [SEP] token
            curr_num_tokens = int((1 - step) * text_len)
            _, top_bboxes_indices = cam_pure_text.topk(k=curr_num_tokens, dim=-1)
            top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()

            # add back [CLS], [SEP] tokens
            top_bboxes_indices = [0, cam_text.shape[0] - 1] + \
                                 [top_bboxes_indices[i] + 1 for i in range(len(top_bboxes_indices))]
            # text tokens must be sorted for positional embedding to work
            top_bboxes_indices = sorted(top_bboxes_indices)

            curr_input_ids = inputs.input_ids[:, top_bboxes_indices]
            curr_attention_mask = inputs.attention_mask[:, top_bboxes_indices]
            curr_token_ids = inputs.token_type_ids[:, top_bboxes_indices]

            output = self.lxmert_vqa(
                input_ids=curr_input_ids.to("cuda"),
                attention_mask=curr_attention_mask.to("cuda"),
                visual_feats=features.to("cuda"),
                visual_pos=normalized_boxes.to("cuda"),
                token_type_ids=curr_token_ids.to("cuda"),
                return_dict=True,
                output_attentions=False,
            )

            answer = self.vqa_answers[output.question_answering_score.argmax()]
            accuracy = item["label"].get(answer, 0)
            self.pert_acc[step_idx] += accuracy

        return self.pert_acc
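A toy walk-through of the index bookkeeping above, for a 6-token sequence [CLS] w1 w2 w3 w4 [SEP] in which the two most relevant inner tokens are w2 and w4:

import torch

cam_text = torch.tensor([0.0, 0.2, 0.8, 0.1, 0.6, 0.0])  # per-token relevance
cam_pure = cam_text[1:-1]                  # drop [CLS]/[SEP]
_, top = cam_pure.topk(k=2)                # tensor([1, 3]) -> w2, w4
idx = sorted([0, cam_text.shape[0] - 1] + [i + 1 for i in top.tolist()])
print(idx)  # [0, 2, 4, 5]: [CLS], w2, w4, [SEP], in original order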
def main(args):
    model_pert = ModelPert(args.COCO_path, use_lrp=True)
    ours = GeneratorOurs(model_pert)
    baselines = GeneratorBaselines(model_pert)
    oursNoAggAblation = GeneratorOursAblationNoAggregation(model_pert)
    vqa_dataset = vqa_data.VQADataset(splits="valid")
    vqa_answers = utils.get_data(VQA_URL)
    method_name = args.method

    items = vqa_dataset.data
    random.seed(1234)
    r = list(range(len(items)))
    random.shuffle(r)
    pert_samples_indices = r[:args.num_samples]
    iterator = tqdm([vqa_dataset.data[i] for i in pert_samples_indices])

    test_type = "positive" if args.is_positive_pert else "negative"
    modality = "text" if args.is_text_pert else "image"
    print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, args.method))

    for index, item in enumerate(iterator):
        if method_name == 'transformer_att':
            R_t_t, R_t_i = baselines.generate_transformer_attr(item)
        elif method_name == 'attn_gradcam':
            R_t_t, R_t_i = baselines.generate_attn_gradcam(item)
        elif method_name == 'partial_lrp':
            R_t_t, R_t_i = baselines.generate_partial_lrp(item)
        elif method_name == 'raw_attn':
            R_t_t, R_t_i = baselines.generate_raw_attn(item)
        elif method_name == 'rollout':
            R_t_t, R_t_i = baselines.generate_rollout(item)
        elif method_name == "ours_with_lrp_no_normalization":
            R_t_t, R_t_i = ours.generate_ours(item, normalize_self_attention=False)
        elif method_name == "ours_no_lrp":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False)
        elif method_name == "ours_no_lrp_no_norm":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, normalize_self_attention=False)
        elif method_name == "ours_with_lrp":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=True)
        elif method_name == "ablation_no_self_in_10":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, apply_self_in_rule_10=False)
        elif method_name == "ablation_no_aggregation":
            R_t_t, R_t_i = oursNoAggAblation.generate_ours_no_agg(item, use_lrp=False, normalize_self_attention=False)
        else:
            print("Please enter a valid method name")
            return
        cam_image = R_t_i[0]
        cam_text = R_t_t[0]
        # min-max normalize the relevance maps to [0, 1]
        cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min())
        cam_text = (cam_text - cam_text.min()) / (cam_text.max() - cam_text.min())
        if args.is_text_pert:
            curr_pert_result = model_pert.perturbation_text(item, cam_image, cam_text, args.is_positive_pert)
        else:
            curr_pert_result = model_pert.perturbation_image(item, cam_image, cam_text, args.is_positive_pert)
        # running mean accuracy (%) per perturbation step
        curr_pert_result = [round(res / (index + 1) * 100, 2) for res in curr_pert_result]
        iterator.set_description("Acc: {}".format(curr_pert_result))


if __name__ == "__main__":
    main(args)
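The description printed by the loop is one running-mean accuracy per perturbation step; a minimal matplotlib sketch (matplotlib assumed available, and the accuracy values below are hypothetical) for turning a final curr_pert_result into the usual perturbation curve:

import matplotlib.pyplot as plt

pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
acc_curve = [68.2, 61.5, 52.0, 40.1, 36.8, 31.0, 24.5, 17.9, 8.3]  # hypothetical values
plt.plot([s * 100 for s in pert_steps], acc_curve, marker="o")
plt.xlabel("% of input removed")
plt.ylabel("VQA accuracy (%)")
plt.show()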