WwYc committed
Commit fa96553 · verified · 1 Parent(s): d54d66e

Update generic.py

Files changed (1)
  1. generic.py +25 -33
generic.py CHANGED
@@ -88,50 +88,42 @@ class ModelUsage:
         return self.output
 
 
-def save_image_vis(image_file_path, bbox_scores):
-    bbox_scores = image_scores
-    _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)
+model_lrp = ModelUsage(use_lrp=True)
+lrp = GeneratorOurs(model_lrp)
+baselines = GeneratorBaselines(model_lrp)
+vqa_answers = utils.get_data(VQA_URL)
+
+
+def save_image_vis(image_file_path, question):
+    R_t_t, R_t_i = lrp.generate_ours((image_file_path, question), use_lrp=False,
+                                     normalize_self_attention=True,
+                                     method_name="ours")
+    image_scores = R_t_i[0]
+    text_scores = R_t_t[0]
+    # bbox_scores = image_scores
+    _, top_bboxes_indices = image_scores.topk(k=1, dim=-1)
     img = cv2.imread(image_file_path)
     mask = torch.zeros(img.shape[0], img.shape[1])
-    for index in range(len(bbox_scores)):
+    for index in range(len(image_scores)):
         [x, y, w, h] = model_lrp.bboxes[0][index]
         curr_score_tensor = mask[int(y):int(h), int(x):int(w)]
-        new_score_tensor = torch.ones_like(curr_score_tensor) * bbox_scores[index].item()
+        new_score_tensor = torch.ones_like(curr_score_tensor) * image_scores[index].item()
         mask[int(y):int(h), int(x):int(w)] = torch.max(new_score_tensor, mask[int(y):int(h), int(x):int(w)])
     mask = (mask - mask.min()) / (mask.max() - mask.min())
     mask = mask.unsqueeze_(-1)
     mask = mask.expand(img.shape)
     img = img * mask.cpu().data.numpy()
+    # img = Image.fromarray(np.uint8(img)).convert('RGB')
     cv2.imwrite(
         'lxmert/lxmert/experiments/paper/new.jpg', img)
-    return img
+    text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
+    vis_data_records = [visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1)]
+    html1 = visualization.visualize_text(vis_data_records)
+    return html1.data
 
 
-model_lrp = ModelUsage(use_lrp=True)
-lrp = GeneratorOurs(model_lrp)
-baselines = GeneratorBaselines(model_lrp)
-vqa_answers = utils.get_data(VQA_URL)
-
-image_ids = [
-    # giraffe
-    'COCO_val2014_000000185590',
-    # baseball
-    'COCO_val2014_000000127510',
-    # bath
-    'COCO_val2014_000000324266',
-    # frisbee
-    'COCO_val2014_000000200717'
-]
-
-test_questions_for_images = [
-    ################## paper samples
-    # giraffe
-    "is the animal eating?",
-    # baseball
-    "did he catch the ball?",
-    # bath
-    "is the tub white ?",
-    # frisbee
-    "did the man just catch the frisbee?"
-    ################## paper samples
-]
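For context, a minimal sketch of how the updated save_image_vis might be called after this commit. It assumes the setup at the top of generic.py (model_lrp, lrp, baselines, vqa_answers) has already run; the local image path and output file name are hypothetical, reusing the "giraffe" COCO sample and its question from the lists this diff removes.

# Minimal usage sketch, not part of the commit. Assumes generic.py has
# already initialized model_lrp, lrp, baselines, and vqa_answers as above.
image_path = 'COCO_val2014_000000185590.jpg'   # hypothetical local path to the "giraffe" sample
question = "is the animal eating?"             # matching question from the removed list

# With the new signature, relevance maps are computed inside the function:
# it writes the box-relevance overlay to lxmert/lxmert/experiments/paper/new.jpg
# and returns the token-relevance visualization as an HTML string.
html = save_image_vis(image_path, question)

with open('text_relevance.html', 'w') as f:    # hypothetical output file
    f.write(html)

Note the behavioral change this diff makes: the old version took precomputed bbox_scores and returned the masked image, while the new one takes the raw question, runs lrp.generate_ours itself, and returns the HTML of the text relevance scores instead.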