import logging
import os
import time

import numpy as np
import torch
from torchvision.ops import box_iou

from detectron2.data import MetadataCatalog
from detectron2.evaluation.evaluator import DatasetEvaluator
from detectron2.structures import BoxMode
from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize


class InteractiveEvaluator(DatasetEvaluator):
    """
    Evaluate point interactive IoU metrics.

    Per-sample IoU tensors are collected in ``process``. ``evaluate`` then
    reports, for every threshold in ``self.all_ious``, the average number of
    clicks (NoC) needed to reach that IoU, plus the mean IoU at click
    ``iou_iter``. It also saves a mean-IoU-versus-clicks plot under
    ``<output_dir>/iou_by_clicks``.
    """

    def __init__(
        self,
        dataset_name,
        output_dir,
        max_clicks=20,
        iou_iter=1,
        compute_box=False,
        distributed=True,
    ):
        """
        Args:
            dataset_name: name passed to ``MetadataCatalog.get``.
            output_dir: directory under which the IoU curve image is saved.
            max_clicks: NoC value assigned to a sample that never reaches a
                given IoU threshold.
            iou_iter: 1-based click index at which the mean IoU is reported.
            compute_box: accepted for interface compatibility; not used by
                this class.
            distributed: if True, per-process statistics are combined with
                ``all_gather`` and only the main process returns results.
        """
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._cpu_device = torch.device("cpu")
        self._output_dir = output_dir

        self.max_clicks = max_clicks
        self.iou_iter = iou_iter
        # The lookup is kept from the original code; its result was never
        # used there, so it is only stored for potential later use.
        self._metadata = MetadataCatalog.get(dataset_name)

    def reset(self):
        """Clear the accumulated per-sample IoUs and (re)set the thresholds."""
        self.iou_list = []
        self.num_samples = 0
        self.all_ious = [0.5, 0.8, 0.85, 0.9]

    def process(self, inputs, outputs):
        """
        Record the ``'mask_iou'`` entry of every output.

        Args:
            inputs: unused.
            outputs: iterable of dicts, each with a ``'mask_iou'`` tensor
                (assumed 1-D with one IoU per click -- TODO confirm against
                the model that produces it).
        """
        self.iou_list += [o['mask_iou'] for o in outputs]
        self.num_samples += len(outputs)

    def _get_noc(self, iou_arr, iou_thr):
        """Return the 1-based index of the first click whose IoU reaches
        ``iou_thr``, or ``self.max_clicks`` if no click does."""
        vals = iou_arr >= iou_thr
        if not vals.any():
            return self.max_clicks
        # max over a boolean tensor yields the index of a True entry;
        # +1 converts the 0-based index to a click count.
        return vals.max(dim=0)[1].item() + 1

    def compute_noc(self):
        """
        Aggregate the NoC and IoU-at-click statistics.

        Returns:
            dict mapping ``str(threshold)`` to the mean number of clicks and
            ``'iou_max_iter'`` to the mean IoU at click ``self.iou_iter``.
            In distributed mode every non-main process returns an empty dict.
            If no samples were processed at all, every value is NaN.
        """
        # Local (per-process) sums; they are divided by the global sample
        # count only after the optional gather below.
        noc_list_sum = {
            str(iou_thr): sum(
                self._get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list
            ) * 1.0
            for iou_thr in self.all_ious
        }

        if self.iou_list:
            iou_at_iter = torch.stack(self.iou_list)[:, self.iou_iter - 1]
            local_iou_sum = iou_at_iter.sum().cpu()
        else:
            # torch.stack rejects an empty list. A process without samples
            # must still contribute a zero so it can reach the gather calls.
            local_iou_sum = torch.zeros(())

        if self._distributed:
            num_samples = sum(all_gather(self.num_samples))
            noc_list_sum_gather = all_gather(noc_list_sum)
            iou_before_max_gather = all_gather(local_iou_sum)

            if not is_main_process():
                return {}

            noc_list_sum = {
                key: sum(part[key] for part in noc_list_sum_gather)
                for key in noc_list_sum_gather[0]
            }
        else:
            # Single-process run: the local statistics are the global ones.
            num_samples = self.num_samples
            iou_before_max_gather = [local_iou_sum]

        if num_samples == 0:
            self._logger.warning(
                "InteractiveEvaluator received no samples; metrics are NaN."
            )
            pred_noc = {key: float('nan') for key in noc_list_sum}
            pred_noc['iou_max_iter'] = float('nan')
            return pred_noc

        pred_noc = {
            key: value / num_samples for key, value in noc_list_sum.items()
        }
        pred_noc['iou_max_iter'] = (
            sum(x.item() for x in iou_before_max_gather) / num_samples
        )
        return pred_noc

    def _draw_iou_curve(self, iou_list, save_dir):
        """Plot the mean IoU per click and save it as
        ``<save_dir>/iou_by_clicks/<unix-timestamp>.png``."""
        # Imported lazily so matplotlib is only needed where a plot is drawn.
        import matplotlib.pyplot as plt

        mean_iou = torch.stack(iou_list, dim=0).mean(dim=0).cpu().numpy()

        output_dir = os.path.join(save_dir, 'iou_by_clicks')
        os.makedirs(output_dir, exist_ok=True)
        file_name = '{}.png'.format(str(int(time.time())))

        fig = plt.figure()
        try:
            # The x axis follows the actual curve length so a mismatch with
            # max_clicks cannot make plotting fail.
            plt.plot(range(1, len(mean_iou) + 1), mean_iou)
            plt.xlabel('Number of clicks')
            plt.ylabel('IoU')
            plt.savefig(os.path.join(output_dir, file_name))
        finally:
            # Release the figure; otherwise figures pile up across runs.
            plt.close(fig)

    def evaluate(self):
        """
        Compute the final metrics, save the IoU curve and log the results.

        Returns:
            ``{'interactive': results}`` where ``results`` has one
            ``'noc@<threshold>'`` entry per threshold and one
            ``'miou@iter<iou_iter>'`` entry. In distributed mode, non-main
            processes return ``None``.
        """
        pred_noc = self.compute_noc()

        if self._distributed and (not is_main_process()):
            return

        # NOTE: the curve is drawn from this process's own samples only;
        # IoU tensors from other processes are not gathered.
        if self.iou_list:
            self._draw_iou_curve(self.iou_list, self._output_dir)

        results = {
            'noc@{}'.format(iou_thr): pred_noc[str(iou_thr)]
            for iou_thr in self.all_ious
        }
        results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']

        self._logger.info(results)
        return {'interactive': results}