File size: 2,052 Bytes
b34d1d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from seg.models.utils.no_obj import NO_OBJ
from seg.models.utils.pan_seg_transform import INSTANCE_OFFSET_HB
from panopticapi.evaluation import PQStat

# Segment id used for the NO_OBJ category. cal_pq recovers a category from an
# id with `id // INSTANCE_OFFSET_HB`, so this is the id whose category part is
# NO_OBJ and whose remainder is zero (assumes INSTANCE_OFFSET_HB is a positive
# integer -- confirm in seg.models.utils.pan_seg_transform).
NO_OBJ_ID = NO_OBJ * INSTANCE_OFFSET_HB


class IoUObj:
    """Running intersection/union totals for one pair of segments.

    Instances are meant to be accumulated in place with ``+=`` / ``-=`` and
    queried through :attr:`iou`. A default-constructed instance holds zero
    for both totals.

    Attributes:
        intersection: Accumulated intersection count.
        union: Accumulated union count.
    """

    def __init__(self, intersection: int = 0, union: int = 0):
        self.intersection = intersection
        self.union = union

    def __repr__(self):
        return "{}(intersection={!r}, union={!r})".format(
            type(self).__name__, self.intersection, self.union
        )

    def __iadd__(self, other):
        """Add ``other``'s totals to this instance in place and return it."""
        self.intersection += other.intersection
        self.union += other.union
        return self

    def __isub__(self, other):
        """Subtract ``other``'s totals from this instance in place and return it."""
        self.intersection -= other.intersection
        self.union -= other.union
        return self

    def is_legal(self):
        """Return True when neither total is negative.

        Totals can only go negative through ``-=``; this lets callers detect
        that more was subtracted than had been accumulated.
        """
        return self.intersection >= 0 and self.union >= 0

    @property
    def iou(self):
        """Intersection over union of the accumulated totals.

        Returns 0.0 when the union is zero (e.g. a freshly constructed
        instance) instead of raising ``ZeroDivisionError``: an empty union
        means there is no overlap to measure.
        """
        if self.union == 0:
            return 0.0
        return self.intersection / self.union


def cal_pq(global_intersection_info, classes):
    """Accumulate panoptic-quality statistics from pairwise segment overlaps.

    Every segment id encodes its category as ``id // INSTANCE_OFFSET_HB``.
    A ground-truth/prediction pair of the same category is a true positive
    when its IoU is strictly greater than 0.5. Ground-truth segments left
    unmatched count as false negatives. Unmatched predictions count as false
    positives unless their IoU with the ``NO_OBJ_ID`` ground-truth entry is
    greater than 0.5.

    Args:
        global_intersection_info: Mapping keyed by ``(gt_id, pred_id)`` whose
            values expose ``union`` and ``iou`` (such as ``IoUObj``).
        classes: Sized collection of classes; only its length is used, as a
            strict upper bound on every predicted category id.

    Returns:
        A ``PQStat`` whose per-category ``tp``, ``iou``, ``fn`` and ``fp``
        fields have been filled in.

    Raises:
        AssertionError: If a predicted category id is not below
            ``len(classes)``.
    """
    num_classes = len(classes)
    gt_matched = set()
    pred_matched = set()

    gt_all = set()
    pred_all = set()

    pq_stat = PQStat()
    for gt_id, pred_id in global_intersection_info:
        gt_cat = gt_id // INSTANCE_OFFSET_HB
        pred_cat = pred_id // INSTANCE_OFFSET_HB
        assert pred_cat < num_classes, "predicted category id out of range"
        overlap = global_intersection_info[gt_id, pred_id]
        if overlap.union == 0:
            # Nothing accumulated for this pair; its IoU is undefined.
            continue
        if gt_cat == NO_OBJ:
            # Pairs against the NO_OBJ ground truth neither register segments
            # nor match; they are only consulted in the false-positive pass.
            continue
        gt_all.add(gt_id)
        pred_all.add(pred_id)
        if gt_cat != pred_cat:
            continue
        iou = overlap.iou
        if iou > 0.5:
            pq_stat[gt_cat].tp += 1
            pq_stat[gt_cat].iou += iou
            gt_matched.add(gt_id)
            pred_matched.add(pred_id)

    # Ground-truth segments that never found a match are false negatives.
    for gt_id in gt_all - gt_matched:
        pq_stat[gt_id // INSTANCE_OFFSET_HB].fn += 1

    # Unmatched predictions are false positives unless they mostly cover the
    # NO_OBJ region.
    for pred_id in pred_all - pred_matched:
        void_key = (NO_OBJ_ID, pred_id)
        # Guard the lookup the same way the main loop does: a prediction with
        # no NO_OBJ entry, or one with an empty union, has no such overlap.
        # Without this the lookup raised KeyError (plain dict) or
        # ZeroDivisionError (default / zero-union entry).
        if void_key in global_intersection_info:
            void_overlap = global_intersection_info[void_key]
            if void_overlap.union != 0 and void_overlap.iou > 0.5:
                continue
        pq_stat[pred_id // INSTANCE_OFFSET_HB].fp += 1

    return pq_stat