HakimAiV2 / datasets /evaluation /interactive_evaluation.py
scdrand23's picture
not working version
814a594
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import numpy as np
import torch
from torchvision.ops import box_iou
from detectron2.structures import BoxMode
from detectron2.data import MetadataCatalog
from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
from detectron2.evaluation.evaluator import DatasetEvaluator
class InteractiveEvaluator(DatasetEvaluator):
    """
    Evaluate point interactive IoU metrics.

    For every processed sample the evaluator stores the ``'mask_iou'`` entry
    of the model output. Each entry is used as a 1-D tensor (it is reduced
    along dim 0 to a scalar and stacked with the other samples), presumably
    holding one IoU value per click -- confirm against the model producing it.

    Reported metrics:
      * ``noc@<thr>``: mean number of clicks needed to reach IoU ``thr``
        (``max_clicks`` is used for samples that never reach it).
      * ``miou@iter<k>``: mean IoU at index ``iou_iter - 1`` of each sample.
    """

    def __init__(
        self,
        dataset_name,
        output_dir,
        max_clicks=20,
        iou_iter=1,
        compute_box=False,
        distributed=True,
    ):
        """
        Args:
            dataset_name: name of the dataset, looked up in ``MetadataCatalog``.
            output_dir: directory under which the IoU-by-clicks plot is saved.
                When falsy (e.g. ``None``), no plot is written.
            max_clicks: click count assigned to samples that never reach a
                given IoU threshold.
            iou_iter: 1-based position in each IoU array used for the
                ``miou@iter`` metric.
            compute_box: accepted for interface compatibility; not used here.
            distributed: if True, gather statistics from all processes and
                report only on the main process.
        """
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._cpu_device = torch.device("cpu")
        self._output_dir = output_dir
        self.max_clicks = max_clicks
        self.iou_iter = iou_iter
        # Kept from the original constructor: the lookup is preserved in case
        # callers rely on the catalog entry being touched for this dataset.
        self._metadata = MetadataCatalog.get(dataset_name)
        # Initialise the accumulators so process() is safe even when the
        # caller does not invoke reset() first.
        self.reset()

    def reset(self):
        """Clear all accumulated per-sample statistics."""
        self.iou_list = []
        self.num_samples = 0
        self.all_ious = [0.5, 0.8, 0.85, 0.9]

    def process(self, inputs, outputs):
        """
        Accumulate the IoU array of every output in the batch.

        Args:
            inputs: unused; present to satisfy the evaluator interface.
            outputs: iterable of dicts, each providing a ``'mask_iou'`` entry.
        """
        self.iou_list += [o['mask_iou'] for o in outputs]
        self.num_samples += len(outputs)

    def compute_noc(self):
        """
        Compute the mean number of clicks (NoC) per IoU threshold and the mean
        IoU at ``iou_iter``.

        Returns:
            dict mapping ``str(threshold)`` to the mean NoC, plus the key
            ``'iou_max_iter'`` with the mean IoU at ``iou_iter``. An empty
            dict is returned on non-main processes in distributed mode, and
            when no samples were processed at all.
        """
        def _get_noc(iou_arr, iou_thr):
            # 1-based index of the first entry reaching the threshold;
            # max_clicks when the threshold is never reached.
            vals = iou_arr >= iou_thr
            return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks

        # Local (per-process) sums; the division happens after gathering.
        noc_list_sum = {
            str(iou_thr): float(
                sum(_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list)
            )
            for iou_thr in self.all_ious
        }
        if self.iou_list:
            iou_at_iter_sum = (
                torch.stack(self.iou_list)[:, self.iou_iter - 1].sum().cpu()
            )
        else:
            # A process without samples must still take part in the gathers
            # below; stacking an empty list would raise before reaching them.
            iou_at_iter_sum = torch.zeros(())

        if self._distributed:
            # Every process has to execute all collectives before any of them
            # returns early.
            num_samples = sum(all_gather(self.num_samples))
            noc_list_sum_gather = all_gather(noc_list_sum)
            iou_before_max_gather = all_gather(iou_at_iter_sum)
            if not is_main_process():
                return {}
            noc_list_sum = {key: 0.0 for key in noc_list_sum_gather[0]}
            for nlg in noc_list_sum_gather:
                for key, value in nlg.items():
                    noc_list_sum[key] += value
        else:
            # Single-process evaluation: the local statistics are the totals.
            num_samples = self.num_samples
            iou_before_max_gather = [iou_at_iter_sum]

        if num_samples == 0:
            self._logger.warning(
                "[InteractiveEvaluator] Did not receive any samples."
            )
            return {}

        pred_noc = {key: value / num_samples for key, value in noc_list_sum.items()}
        pred_noc['iou_max_iter'] = (
            sum(x.item() for x in iou_before_max_gather) / num_samples
        )
        return pred_noc

    def _draw_iou_curve(self, iou_list, save_dir):
        """
        Plot the mean IoU per click and save it as
        ``<save_dir>/iou_by_clicks/<unix-time>.png``.

        Only the IoU arrays held by the calling process are plotted; in
        distributed mode they are not gathered from the other processes.
        """
        # Imported lazily so matplotlib is only required when a plot is made.
        import time

        import matplotlib.pyplot as plt

        curve = torch.stack(iou_list, dim=0).mean(dim=0).cpu().numpy()

        output_dir = os.path.join(save_dir, 'iou_by_clicks')
        os.makedirs(output_dir, exist_ok=True)

        fig = plt.figure()
        try:
            # Take the x-axis from the curve itself so a length different
            # from max_clicks does not make plotting fail.
            plt.plot(range(1, len(curve) + 1), curve)
            plt.xlabel('Number of clicks')
            plt.ylabel('IoU')
            # File name is the current Unix time in whole seconds.
            plt.savefig(os.path.join(output_dir, '{}.png'.format(int(time.time()))))
        finally:
            # Release the figure; otherwise repeated evaluations leak figures.
            plt.close(fig)

    def evaluate(self):
        """
        Aggregate the metrics, save the IoU-by-clicks plot and log the results.

        Returns:
            ``{'interactive': {...}}`` with the ``noc@<thr>`` and
            ``miou@iter<k>`` metrics on the main process; ``None`` on non-main
            processes in distributed mode; ``{}`` when no samples were
            processed.
        """
        pred_noc = self.compute_noc()
        if self._distributed and (not is_main_process()):
            return
        if not pred_noc:
            return {}

        # Plot only when there is somewhere to write to and something to plot.
        if self._output_dir and self.iou_list:
            self._draw_iou_curve(self.iou_list, self._output_dir)

        results = {
            'noc@{}'.format(iou_thr): pred_noc[str(iou_thr)]
            for iou_thr in self.all_ious
        }
        results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
        self._logger.info(results)
        return {'interactive': results}