WwYc committed
Commit 2535139 · verified · 1 Parent(s): 08d7644

Delete lxmert/perturbation.py

Files changed (1)
  1. lxmert/perturbation.py +0 -254
lxmert/perturbation.py DELETED
@@ -1,254 +0,0 @@
- from lxmert.lxmert.src.tasks import vqa_data
- from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
- import lxmert.lxmert.src.vqa_utils as utils
- from lxmert.lxmert.src.processing_image import Preprocess
- from transformers import LxmertTokenizer
- from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
- from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
- from tqdm import tqdm
- from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
- import random
- from lxmert.lxmert.src.param import args
-
- OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
- ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
- VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"
-
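- # ModelPert wraps the Faster R-CNN feature extractor and the LXMERT VQA model,
- # and accumulates accuracy over a fixed schedule of perturbation steps.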
- class ModelPert:
-     def __init__(self, COCO_val_path, use_lrp=False):
-         self.COCO_VAL_PATH = COCO_val_path
-         self.vqa_answers = utils.get_data(VQA_URL)
-
-         # load models and model components
-         self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
-         self.frcnn_cfg.MODEL.DEVICE = "cuda"
-
-         self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
-
-         self.image_preprocess = Preprocess(self.frcnn_cfg)
-
-         self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
-
-         if use_lrp:
-             self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
-         else:
-             self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
-
-         self.lxmert_vqa.eval()
-         self.model = self.lxmert_vqa
-
-         self.vqa_dataset = vqa_data.VQADataset(splits="valid")
-
-         # fraction of the input removed at each perturbation step
-         self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
-         self.pert_acc = [0] * len(self.pert_steps)
-
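-     # forward: runs Faster R-CNN on the image, tokenizes the question, and
-     # caches token/box counts that the perturbation methods reuse.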
-     def forward(self, item):
-         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
-         self.image_file_path = image_file_path
-         self.image_id = item['img_id']
-         # run frcnn
-         images, sizes, scales_yx = self.image_preprocess(image_file_path)
-         output_dict = self.frcnn(
-             images,
-             sizes,
-             scales_yx=scales_yx,
-             padding="max_detections",
-             max_detections=self.frcnn_cfg.max_detections,
-             return_tensors="pt"
-         )
-         inputs = self.lxmert_tokenizer(
-             item['sent'],
-             truncation=True,
-             return_token_type_ids=True,
-             return_attention_mask=True,
-             add_special_tokens=True,
-             return_tensors="pt"
-         )
-         self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
-         self.text_len = len(self.question_tokens)
-         # Very important that the boxes are normalized
-         normalized_boxes = output_dict.get("normalized_boxes")
-         features = output_dict.get("roi_features")
-         self.image_boxes_len = features.shape[1]
-         self.bboxes = output_dict.get("boxes")
-         self.output = self.lxmert_vqa(
-             input_ids=inputs.input_ids.to("cuda"),
-             attention_mask=inputs.attention_mask.to("cuda"),
-             visual_feats=features.to("cuda"),
-             visual_pos=normalized_boxes.to("cuda"),
-             token_type_ids=inputs.token_type_ids.to("cuda"),
-             return_dict=True,
-             output_attentions=False,
-         )
-         return self.output
-
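-     # perturbation_image: at each step keep only the top (1 - step) fraction of
-     # boxes ranked by cam_image relevance (positive perturbation negates the map,
-     # so the most relevant boxes are removed first) and accumulate VQA accuracy.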
-     def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False):
-         # note: relies on self.image_boxes_len being set by a prior call to forward
-         if is_positive_pert:
-             cam_image = cam_image * (-1)
-         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
-         # run frcnn
-         images, sizes, scales_yx = self.image_preprocess(image_file_path)
-         output_dict = self.frcnn(
-             images,
-             sizes,
-             scales_yx=scales_yx,
-             padding="max_detections",
-             max_detections=self.frcnn_cfg.max_detections,
-             return_tensors="pt"
-         )
-         inputs = self.lxmert_tokenizer(
-             item['sent'],
-             truncation=True,
-             return_token_type_ids=True,
-             return_attention_mask=True,
-             add_special_tokens=True,
-             return_tensors="pt"
-         )
-         # Very important that the boxes are normalized
-         normalized_boxes = output_dict.get("normalized_boxes")
-         features = output_dict.get("roi_features")
-         for step_idx, step in enumerate(self.pert_steps):
-             # find top step boxes
-             curr_num_boxes = int((1 - step) * self.image_boxes_len)
-             _, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1)
-             top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()
-
-             curr_features = features[:, top_bboxes_indices, :]
-             curr_pos = normalized_boxes[:, top_bboxes_indices, :]
-
-             output = self.lxmert_vqa(
-                 input_ids=inputs.input_ids.to("cuda"),
-                 attention_mask=inputs.attention_mask.to("cuda"),
-                 visual_feats=curr_features.to("cuda"),
-                 visual_pos=curr_pos.to("cuda"),
-                 token_type_ids=inputs.token_type_ids.to("cuda"),
-                 return_dict=True,
-                 output_attentions=False,
-             )
-
-             answer = self.vqa_answers[output.question_answering_score.argmax()]
-             accuracy = item["label"].get(answer, 0)
-             self.pert_acc[step_idx] += accuracy
-
-         return self.pert_acc
-
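-     # perturbation_text: same protocol over question tokens; [CLS] and [SEP] are
-     # always kept, and indices are re-sorted so positional embeddings stay valid.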
-     def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False):
-         if is_positive_pert:
-             cam_text = cam_text * (-1)
-         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
-         # run frcnn
-         images, sizes, scales_yx = self.image_preprocess(image_file_path)
-         output_dict = self.frcnn(
-             images,
-             sizes,
-             scales_yx=scales_yx,
-             padding="max_detections",
-             max_detections=self.frcnn_cfg.max_detections,
-             return_tensors="pt"
-         )
-         inputs = self.lxmert_tokenizer(
-             item['sent'],
-             truncation=True,
-             return_token_type_ids=True,
-             return_attention_mask=True,
-             add_special_tokens=True,
-             return_tensors="pt"
-         )
-         # Very important that the boxes are normalized
-         normalized_boxes = output_dict.get("normalized_boxes")
-         features = output_dict.get("roi_features")
-         for step_idx, step in enumerate(self.pert_steps):
-             # we must keep the [CLS] token in order to have the classification
-             # we also keep the [SEP] token
-             cam_pure_text = cam_text[1:-1]
-             text_len = cam_pure_text.shape[0]
-             # find top step tokens, without the [CLS] token and the [SEP] token
-             curr_num_tokens = int((1 - step) * text_len)
-             _, top_bboxes_indices = cam_pure_text.topk(k=curr_num_tokens, dim=-1)
-             top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()
-
-             # add back [CLS], [SEP] tokens
-             top_bboxes_indices = [0, cam_text.shape[0] - 1] + \
-                 [top_bboxes_indices[i] + 1 for i in range(len(top_bboxes_indices))]
-             # text tokens must be sorted for positional embedding to work
-             top_bboxes_indices = sorted(top_bboxes_indices)
-
-             curr_input_ids = inputs.input_ids[:, top_bboxes_indices]
-             curr_attention_mask = inputs.attention_mask[:, top_bboxes_indices]
-             curr_token_ids = inputs.token_type_ids[:, top_bboxes_indices]
-
-             output = self.lxmert_vqa(
-                 input_ids=curr_input_ids.to("cuda"),
-                 attention_mask=curr_attention_mask.to("cuda"),
-                 visual_feats=features.to("cuda"),
-                 visual_pos=normalized_boxes.to("cuda"),
-                 token_type_ids=curr_token_ids.to("cuda"),
-                 return_dict=True,
-                 output_attentions=False,
-             )
-
-             answer = self.vqa_answers[output.question_answering_score.argmax()]
-             accuracy = item["label"].get(answer, 0)
-             self.pert_acc[step_idx] += accuracy
-
-         return self.pert_acc
-
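- # main: samples args.num_samples random validation questions, generates relevance
- # maps with the selected method, and reports running perturbation accuracy.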
- def main(args):
-     model_pert = ModelPert(args.COCO_path, use_lrp=True)
-     ours = GeneratorOurs(model_pert)
-     baselines = GeneratorBaselines(model_pert)
-     oursNoAggAblation = GeneratorOursAblationNoAggregation(model_pert)
-     vqa_dataset = vqa_data.VQADataset(splits="valid")
-     vqa_answers = utils.get_data(VQA_URL)
-     method_name = args.method
-
-     items = vqa_dataset.data
-     random.seed(1234)
-     r = list(range(len(items)))
-     random.shuffle(r)
-     pert_samples_indices = r[:args.num_samples]
-     iterator = tqdm([vqa_dataset.data[i] for i in pert_samples_indices])
-
-     test_type = "positive" if args.is_positive_pert else "negative"
-     modality = "text" if args.is_text_pert else "image"
-     print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, args.method))
-
-     for index, item in enumerate(iterator):
-         if method_name == 'transformer_att':
-             R_t_t, R_t_i = baselines.generate_transformer_attr(item)
-         elif method_name == 'attn_gradcam':
-             R_t_t, R_t_i = baselines.generate_attn_gradcam(item)
-         elif method_name == 'partial_lrp':
-             R_t_t, R_t_i = baselines.generate_partial_lrp(item)
-         elif method_name == 'raw_attn':
-             R_t_t, R_t_i = baselines.generate_raw_attn(item)
-         elif method_name == 'rollout':
-             R_t_t, R_t_i = baselines.generate_rollout(item)
-         elif method_name == "ours_with_lrp_no_normalization":
-             R_t_t, R_t_i = ours.generate_ours(item, normalize_self_attention=False)
-         elif method_name == "ours_no_lrp":
-             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False)
-         elif method_name == "ours_no_lrp_no_norm":
-             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, normalize_self_attention=False)
-         elif method_name == "ours_with_lrp":
-             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=True)
-         elif method_name == "ablation_no_self_in_10":
-             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, apply_self_in_rule_10=False)
-         elif method_name == "ablation_no_aggregation":
-             R_t_t, R_t_i = oursNoAggAblation.generate_ours_no_agg(item, use_lrp=False, normalize_self_attention=False)
-         else:
-             print("Please enter a valid method name")
-             return
-         cam_image = R_t_i[0]
-         cam_text = R_t_t[0]
-         # min-max normalize the relevance maps to [0, 1]
-         cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min())
-         cam_text = (cam_text - cam_text.min()) / (cam_text.max() - cam_text.min())
-         if args.is_text_pert:
-             curr_pert_result = model_pert.perturbation_text(item, cam_image, cam_text, args.is_positive_pert)
-         else:
-             curr_pert_result = model_pert.perturbation_image(item, cam_image, cam_text, args.is_positive_pert)
-         # running average over the samples processed so far, as a percentage
-         curr_pert_result = [round(res / (index + 1) * 100, 2) for res in curr_pert_result]
-         iterator.set_description("Acc: {}".format(curr_pert_result))
-
- if __name__ == "__main__":
-     main(args)