import json

import torch
from tqdm import tqdm

from box_ops import box_iou
from decode_string import decode_bbox_from_caption

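# Assumed ground-truth layout (inferred from the key accesses below): an
# MDETR-processed COCO-style JSON in which images[i] is aligned with
# annotations[i], roughly:
#   {"images":      [{"id": ..., "dataset_name": "refcoco", "height": ...,
#                     "width": ..., "caption": "..."}, ...],
#    "annotations": [{"image_id": ..., "bbox": [x, y, w, h]}, ...]}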
class RefExpEvaluatorFromTxt:
    """Evaluate referring-expression predictions (one generated caption per line)
    against MDETR-processed RefCOCO/RefCOCO+/RefCOCOg ground truth."""

    def __init__(self, refexp_gt_path, k=(1, -1), thresh_iou=0.5):
        assert isinstance(k, (list, tuple))
        with open(refexp_gt_path, 'r') as f:
            self.refexp_gt = json.load(f)
        self.img_ids = [item['id'] for item in self.refexp_gt['images']]
        print(f"Loaded {len(self.img_ids)} images")
        print(f"Loaded {len(self.refexp_gt['annotations'])} annotations")
        self.k = k
        self.thresh_iou = thresh_iou

    def summarize(self,
                  prediction_file: str,
                  quantized_size: int = 32,
                  verbose: bool = False):
        
        # get the predictions
        with open(prediction_file, 'r', encoding='utf-8') as f:
            predict_all_lines = f.readlines()
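        # Each hypothesis line is assumed to follow the fairseq-generate style
        # "H-<sample_id>\t<score>\t<generated caption>"; that convention is what
        # the 'H-' prefix check below relies on.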
        # keep only valid hypothesis lines, i.e. those whose first field is 'H-<digits>'
        filter_prediction_lines = []
        for line in predict_all_lines:
            line_pieces = line.strip('\n').split('\t')
            if 'H-' in line_pieces[0]:
                if line_pieces[0].split('-')[-1].isdigit():
                    filter_prediction_lines.append(line)
        
        predict_all_lines = filter_prediction_lines
        predict_index = 0
        
        dataset2score = {
            "refcoco": {k: 0.0 for k in self.k},
            "refcoco+": {k: 0.0 for k in self.k},
            "refcocog": {k: 0.0 for k in self.k},
        }
        dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0}
        for item_img, item_ann in tqdm(zip(self.refexp_gt['images'], self.refexp_gt['annotations'])):
            # stop once every prediction has been consumed
            if predict_index == len(predict_all_lines):
                break

            if item_img['id'] != item_ann['image_id']:
                raise ValueError(f"Ann\n{item_ann}\ndoes not match\n{item_img}")
            
            dataset_name = item_img['dataset_name']
            img_height = item_img['height']
            img_width = item_img['width']
            caption = item_img['caption']
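            # COCO boxes are [x, y, w, h]; convert to corner format [x1, y1, x2, y2],
            # e.g. [10, 20, 30, 40] -> [10, 20, 40, 60]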
            target_bbox = item_ann["bbox"]
            converted_bbox = [
                target_bbox[0],
                target_bbox[1],
                target_bbox[2] + target_bbox[0],
                target_bbox[3] + target_bbox[1],
            ]
            target_bbox = torch.as_tensor(converted_bbox).view(-1, 4)

            # the generated caption follows the closing </image> tag of the prompt
            prediction_line = predict_all_lines[predict_index].split('</image>')[-1]
            predict_index += 1
            
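            # decode_bbox_from_caption is assumed to return a list of
            # (phrase, x1, y1, x2, y2) tuples with coordinates already
            # de-quantized to [0, 1] (hence the width/height scaling below)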
            collect_entity_location = decode_bbox_from_caption(prediction_line, quantized_size=quantized_size, verbose=verbose)
            
            predict_boxes = []
            for (p_pred, p_x1, p_y1, p_x2, p_y2) in collect_entity_location:
                # keep only boxes whose predicted phrase matches the query caption
                if p_pred.strip() != caption.strip():
                    continue
                pred_box = [p_x1 * img_width, p_y1 * img_height, p_x2 * img_width, p_y2 * img_height]
                predict_boxes.append(pred_box)
                    
            if len(predict_boxes) == 0:
                print(f"Can't find a valid bbox for the given phrase: {caption}\n{collect_entity_location}")
                print("Falling back to a zero-area box, which scores the sample as a miss")
                predict_boxes = [[0., 0., 0., 0.]]
                
            predict_boxes = torch.as_tensor(predict_boxes).view(-1, 4)
            
            # iou has shape (num_predictions, 1): IoU of each predicted box with the single target
            iou, _ = box_iou(predict_boxes, target_bbox)
            # 'mean' scores the coordinate-wise average of all predicted boxes
            mean_iou, _ = box_iou(predict_boxes.mean(0).view(-1, 4), target_bbox)
            for k in self.k:
                if k == 'upper bound':
                    # a hit if any predicted box reaches the IoU threshold
                    if iou.max() >= self.thresh_iou:
                        dataset2score[dataset_name][k] += 1.0
                elif k == 'mean':
                    if mean_iou.max() >= self.thresh_iou:
                        dataset2score[dataset_name][k] += 1.0
                else:
                    # precision@k: a hit if one of the first k predicted boxes reaches the threshold
                    if iou[:k, 0].max() >= self.thresh_iou:
                        dataset2score[dataset_name][k] += 1.0

            dataset2count[dataset_name] += 1.0

        for key, value in dataset2score.items():
            for k in self.k:
                try:
                    value[k] /= dataset2count[key]
                except ZeroDivisionError:
                    # no samples for this dataset in the prediction file
                    pass
                
        results = {}
        for key, value in dataset2score.items():
            # keep the scores in the order of self.k so they line up with the metric names
            results[key] = [value[k] for k in self.k]
            print(f"Dataset: {key} - Precision @ {self.k}: {results[key]}\n")

        return results

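# A minimal sanity sketch (not part of the original evaluator), assuming
# box_ops.box_iou follows the torchvision convention of returning (iou, union):
def _sanity_check_scoring():
    pred = torch.as_tensor([[0., 0., 10., 10.]])    # predicted box, xyxy
    target = torch.as_tensor([[0., 0., 10., 12.]])  # ground-truth box, xyxy
    iou, _ = box_iou(pred, target)                  # IoU = 100 / 120 ≈ 0.83
    assert iou[0, 0] >= 0.5                         # counts as a hit at the 0.5 threshold
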

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('prediction_file', help='path to the model prediction txt file')
    parser.add_argument('annotation_file', help='path to the annotation json, e.g. /path/to/mdetr_processed_json_annotations')
    parser.add_argument('--quantized_size', default=32, type=int)
    
    args = parser.parse_args()
    
    evaluator = RefExpEvaluatorFromTxt(
        refexp_gt_path=args.annotation_file, 
        k=(1, 'mean', 'upper bound'), 
        thresh_iou=0.5,
    )
    
    evaluator.summarize(args.prediction_file, args.quantized_size, verbose=False)
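
    # Example invocation (the script name is a placeholder for this file):
    #   python refexp_eval.py predictions.txt refcoco_annotations.json --quantized_size 32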