File size: 481 Bytes
1803579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""

import torch


def compute_mae(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
    """
    :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
    :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
    """
    return torch.mean(
        torch.abs(pred_mask - gt_mask.to(torch.float32)), dim=(-1, -2)
    ).cpu()