""" Code borrowed from SelfMask: https://github.com/NoelShin/selfmask """ import torch class FMeasure: def __init__( self, default_thres: float = 0.5, beta_square: float = 0.3, n_bins: int = 255, eps: float = 1e-7, ): """ :param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5 :param beta_square: a hyperparameter for F-measure. Default: 0.3 :param n_bins: the number of thresholds that will be tested for F-max. Default: 255 :param eps: a small value for numerical stability """ self.beta_square = beta_square self.default_thres = default_thres self.eps = eps self.n_bins = n_bins def _compute_precision_recall( self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor ) -> torch.Tensor: """ :param binary_pred_mask: (B x H x W) or (H x W) :param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask """ tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=(-1, -2)) tp_fp = binary_pred_mask.sum(dim=(-1, -2)) tp_fn = gt_mask.sum(dim=(-1, -2)) prec = tp / (tp_fp + self.eps) recall = tp / (tp_fn + self.eps) return prec, recall def _compute_f_measure( self, pred_mask: torch.Tensor, gt_mask: torch.Tensor, thresholds: torch.Tensor = None, ) -> torch.Tensor: if thresholds is None: binary_pred_mask = pred_mask > self.default_thres else: binary_pred_mask = pred_mask > thresholds prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask) f_measure = ((1 + (self.beta_square**2)) * prec * recall) / ( (self.beta_square**2) * prec + recall + self.eps ) return f_measure.cpu() def _compute_f_max( self, pred_mask: torch.Tensor, gt_mask: torch.Tensor ) -> torch.Tensor: """Compute self.n_bins + 1 F-measures, each of which has a different threshold, then return the maximum F-measure among them. :param pred_mask: (H x W) :param gt_mask: (H x W) """ # pred_masks, gt_masks: H x W -> self.n_bins x H x W pred_masks = pred_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1) gt_masks = gt_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1) # thresholds: self.n_bins x 1 x 1 thresholds = ( torch.arange(0, 1, 1 / self.n_bins) .view(self.n_bins, 1, 1) .to(pred_masks.device) ) # f_measures: self.n_bins f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds) return torch.max(f_measures).cpu(), f_measures def _compute_f_mean( self, pred_mask: torch.Tensor, gt_mask: torch.Tensor, ) -> torch.Tensor: adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True) binary_pred_mask = pred_mask > adaptive_thres prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask) f_mean = ((1 + (self.beta_square**2)) * prec * recall) / ( (self.beta_square**2) * prec + recall + self.eps ) return f_mean.cpu() def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict: """ :param pred_mask: (H x W) a normalized prediction mask with values in [0, 1] :param gt_mask: (H x W) a binary ground truth mask with values in {0, 1} :return: a dictionary with keys being "f_measure" and "f_max" and values being the respective values. """ outputs: dict = dict() for k in ("f_measure", "f_mean"): outputs.update({k: getattr(self, f"_compute_{k}")(pred_mask, gt_mask)}) f_max_, all_f = self._compute_f_max(pred_mask, gt_mask) outputs["f_max"] = f_max_ outputs["all_f"] = all_f # List of all f values for all thresholds return outputs