# Spaces:
# Runtime error
# Runtime error
from seg.models.utils.no_obj import NO_OBJ | |
from seg.models.utils.pan_seg_transform import INSTANCE_OFFSET_HB | |
from panopticapi.evaluation import PQStat | |
# Segment id of the "no object" region. Decoding it the same way cal_pq
# decodes every id (``segment_id // INSTANCE_OFFSET_HB``) gives category
# NO_OBJ. cal_pq looks up ``(NO_OBJ_ID, pred_id)`` to measure how much a
# prediction overlaps that region.
NO_OBJ_ID = INSTANCE_OFFSET_HB * NO_OBJ
class IoUObj:
    """Running intersection/union totals for one (gt, pred) segment pair.

    Totals can be combined in place with ``+=`` and ``-=`` using another
    ``IoUObj``. :meth:`iou` returns their ratio.
    """

    def __init__(self, intersection: int = 0, union: int = 0):
        self.intersection = intersection
        self.union = union

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(intersection={self.intersection!r}, "
            f"union={self.union!r})"
        )

    def __iadd__(self, other: "IoUObj") -> "IoUObj":
        """Add ``other``'s totals to this object in place and return it."""
        self.intersection += other.intersection
        self.union += other.union
        return self

    def __isub__(self, other: "IoUObj") -> "IoUObj":
        """Subtract ``other``'s totals from this object in place and return it."""
        self.intersection -= other.intersection
        self.union -= other.union
        return self

    def is_legal(self) -> bool:
        """Return True when neither total is negative (e.g. after ``-=``)."""
        return self.intersection >= 0 and self.union >= 0

    def iou(self) -> float:
        """Return ``intersection / union``.

        An empty union (e.g. a default-constructed object) yields ``0.0``
        instead of raising ``ZeroDivisionError``. Nothing was accumulated,
        so there is no overlap to report.
        """
        if self.union == 0:
            return 0.0
        return self.intersection / self.union
def cal_pq(global_intersection_info, classes):
    """Accumulate panoptic-quality statistics from pairwise segment overlaps.

    Args:
        global_intersection_info: Mapping keyed by ``(gt_id, pred_id)``.
            Each value exposes ``.union`` and ``.iou()`` (e.g. ``IoUObj``).
            A segment's category is ``segment_id // INSTANCE_OFFSET_HB``.
        classes: Sized collection of classes. Only its length is used, to
            bound the predicted category ids.

    Returns:
        A ``PQStat`` with per-category ``tp``, ``fp``, ``fn`` counts and the
        summed IoU of the matched pairs.

    Raises:
        AssertionError: if a predicted category id is ``>= len(classes)``.

    Counting rules:
        * Pairs with an empty union, or whose GT category is ``NO_OBJ``, are
          ignored when collecting candidate segments.
        * A same-category pair with IoU > 0.5 is a true positive.
        * Every collected GT segment left unmatched is a false negative.
        * Every collected prediction left unmatched is a false positive,
          unless its IoU with the ``NO_OBJ_ID`` region is > 0.5.
    """
    num_classes = len(classes)
    gt_matched = set()
    pred_matched = set()
    gt_all = set()
    pred_all = set()
    pq_stat = PQStat()

    for (gt_id, pred_id), overlap in global_intersection_info.items():
        gt_cat = gt_id // INSTANCE_OFFSET_HB
        pred_cat = pred_id // INSTANCE_OFFSET_HB
        assert pred_cat < num_classes
        if overlap.union == 0:
            continue
        if gt_cat == NO_OBJ:
            continue
        gt_all.add(gt_id)
        # NOTE(review): a prediction that overlaps no non-NO_OBJ GT segment
        # never enters pred_all, so it is never counted as a false positive.
        # Confirm this is intended.
        pred_all.add(pred_id)
        if gt_cat != pred_cat:
            continue
        iou = overlap.iou()
        if iou > 0.5:
            # Presumably at most one pred per GT can exceed 0.5 (assumes
            # non-overlapping segments) -- TODO confirm.
            pq_stat[gt_cat].tp += 1
            pq_stat[gt_cat].iou += iou
            gt_matched.add(gt_id)
            pred_matched.add(pred_id)

    for gt_id in gt_all - gt_matched:
        pq_stat[gt_id // INSTANCE_OFFSET_HB].fn += 1

    for pred_id in pred_all - pred_matched:
        # .get() avoids a KeyError (plain dict) or inserting an empty entry
        # (defaultdict) when the prediction never touched the NO_OBJ region.
        # The union check avoids dividing by zero.
        void_overlap = global_intersection_info.get((NO_OBJ_ID, pred_id))
        if (
            void_overlap is not None
            and void_overlap.union > 0
            and void_overlap.iou() > 0.5
        ):
            continue
        pq_stat[pred_id // INSTANCE_OFFSET_HB].fp += 1

    return pq_stat