"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""
from typing import Optional

import torch


def compute_pixel_accuracy(
    pred_mask: torch.Tensor,
    gt_mask: torch.Tensor,
    threshold: Optional[float] = 0.5,
) -> torch.Tensor:
    """
    Measure the fraction of pixels on which a prediction agrees with the ground truth.

    :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
    :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
    :param threshold: cut-off used to binarize pred_mask (a pixel is foreground when its value is
        strictly greater than the cut-off); pass None to compare pred_mask against gt_mask as is
    :return: a float32 tensor moved to the CPU, holding the agreement rate averaged over the last
        two dimensions (a scalar for 2-D inputs, one value per batch entry for 3-D inputs)
    """
    # Binarize only when a cut-off is supplied; otherwise trust the caller's mask.
    candidate = pred_mask if threshold is None else pred_mask > threshold

    # Element-wise agreement, cast to float so it can be averaged.
    agreement = (candidate == gt_mask).to(torch.float32)

    per_mask_accuracy = agreement.mean(dim=(-1, -2))
    return per_mask_accuracy.cpu()