File size: 4,067 Bytes
1803579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""

import torch


class FMeasure:
    """F-measure metrics between a soft prediction mask and a binary ground-truth mask.

    Calling an instance returns a dictionary with four entries:

    * ``f_measure``: F-measure after binarizing at ``default_thres``.
    * ``f_mean``: F-measure after binarizing at twice the mean prediction value.
    * ``f_max``: the maximum F-measure over a sweep of ``n_bins`` thresholds.
    * ``all_f``: the F-measure of every threshold in that sweep.
    """

    def __init__(
        self,
        default_thres: float = 0.5,
        beta_square: float = 0.3,
        n_bins: int = 255,
        eps: float = 1e-7,
    ):
        """
        :param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5
        :param beta_square: a hyperparameter for F-measure. Default: 0.3
        :param n_bins: the number of thresholds that will be tested for F-max. Default: 255
        :param eps: a small value for numerical stability
        """

        self.beta_square = beta_square
        self.default_thres = default_thres
        self.eps = eps
        self.n_bins = n_bins

    def _compute_precision_recall(
        self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor
    ) -> tuple:
        """Compute precision and recall over the last two (spatial) dimensions.

        :param binary_pred_mask: (B x H x W) or (H x W)
        :param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask
        :return: a ``(precision, recall)`` tuple of tensors; the two spatial dimensions are reduced away.
        """
        spatial_dims = (-1, -2)
        tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=spatial_dims)
        tp_fp = binary_pred_mask.sum(dim=spatial_dims)  # every predicted positive
        tp_fn = gt_mask.sum(dim=spatial_dims)  # every ground-truth positive

        # eps keeps the ratios finite when a mask has no positive pixel.
        prec = tp / (tp_fp + self.eps)
        recall = tp / (tp_fn + self.eps)
        return prec, recall

    def _f_beta(self, prec: torch.Tensor, recall: torch.Tensor) -> torch.Tensor:
        """Combine precision and recall into the weighted F-measure used by every metric of this class."""
        # NOTE(review): ``beta_square`` is squared once more here, so the
        # default 0.3 acts as an effective weight of 0.09. This is kept
        # unchanged so that scores remain comparable with numbers already
        # produced by this code - confirm whether it is intended.
        weight = self.beta_square**2
        return ((1 + weight) * prec * recall) / (weight * prec + recall + self.eps)

    def _compute_f_measure(
        self,
        pred_mask: torch.Tensor,
        gt_mask: torch.Tensor,
        thresholds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Binarize ``pred_mask`` and compute its F-measure against ``gt_mask``.

        :param pred_mask: (B x H x W) or (H x W) prediction mask
        :param gt_mask: ground-truth mask, same shape as pred_mask
        :param thresholds: optional tensor broadcastable against pred_mask; when None, ``self.default_thres`` is
            used instead
        :return: the F-measure(s) as a CPU tensor
        """
        cutoff = self.default_thres if thresholds is None else thresholds
        binary_pred_mask = pred_mask > cutoff

        prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
        return self._f_beta(prec, recall).cpu()

    def _compute_f_max(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> tuple:
        """Compute self.n_bins F-measures, each of which has a different threshold, then return the maximum
        F-measure among them together with all of them.

        :param pred_mask: (H x W)
        :param gt_mask: (H x W)
        :return: a ``(f_max, f_measures)`` tuple of CPU tensors, where ``f_measures`` has ``self.n_bins`` entries
        """
        n_bins = self.n_bins

        # thresholds: n_bins x 1 x 1, evenly spaced from 0 with step 1 / n_bins.
        # With a floating-point step torch.arange may emit one extra element
        # (e.g. for n_bins=49), which would break the view below, so the
        # result is trimmed to exactly n_bins entries.
        thresholds = (
            torch.arange(0, 1, 1 / n_bins)[:n_bins]
            .view(n_bins, 1, 1)
            .to(pred_mask.device)
        )

        # pred_masks, gt_masks: H x W -> n_bins x H x W (expanded views, no copy)
        pred_masks = pred_mask.unsqueeze(dim=0).expand(n_bins, -1, -1)
        gt_masks = gt_mask.unsqueeze(dim=0).expand(n_bins, -1, -1)

        # f_measures: n_bins, already moved to the CPU by _compute_f_measure
        f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds)
        return torch.max(f_measures), f_measures

    def _compute_f_mean(
        self,
        pred_mask: torch.Tensor,
        gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the F-measure with an adaptive threshold of twice the mean prediction value.

        :param pred_mask: (B x H x W) or (H x W) prediction mask
        :param gt_mask: ground-truth mask, same shape as pred_mask
        :return: the F-measure(s) as a CPU tensor
        """
        adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True)
        binary_pred_mask = pred_mask > adaptive_thres

        prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
        return self._f_beta(prec, recall).cpu()

    def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
        """
        :param pred_mask: (H x W) a normalized prediction mask with values in [0, 1]
        :param gt_mask: (H x W) a binary ground truth mask with values in {0, 1}
        :return: a dictionary with keys "f_measure", "f_mean", "f_max" and "all_f"; "all_f" holds the F-measure of
            every threshold tested for "f_max".
        """
        f_max, all_f = self._compute_f_max(pred_mask, gt_mask)
        return {
            "f_measure": self._compute_f_measure(pred_mask, gt_mask),
            "f_mean": self._compute_f_mean(pred_mask, gt_mask),
            "f_max": f_max,
            "all_f": all_f,  # F-measures for all thresholds
        }