# Source snapshot: commit ab687e7 (file size 2,304 bytes).
import logging
from typing import List
import torch
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch"
__email__ = "[email protected]"
__status__ = "Production"
# ---------------------------------------------------------------------------
# module metrics
#
# General functions to compute custom metrics.
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Module Methods
# ---------------------------------------------------------------------------
# Smoothing term added to both numerator and denominator in binary_mean_iou
# so the ratio stays finite (and evaluates to 1) when both masks are empty.
EPSILON = 1e-15
# ------------------------------ Metric Functions -------------------------- #
def iou_val(y_true, y_pred):
    """Compute the intersection-over-union (Jaccard index) of two masks.

    Both inputs are treated as boolean masks by ``np.logical_and`` /
    ``np.logical_or`` (non-zero values count as positive) and must be
    broadcast-compatible.

    Args:
        y_true: Ground-truth mask (array-like).
        y_pred: Predicted mask (array-like).

    Returns:
        |y_true AND y_pred| / |y_true OR y_pred|. When the union is empty
        (neither mask has any positive element) the masks agree perfectly
        and 1.0 is returned instead of dividing by zero.
    """
    intersection = np.logical_and(y_true, y_pred)
    union = np.logical_or(y_true, y_pred)
    union_count = np.sum(union)
    # Without this guard 0/0 yields NaN plus a RuntimeWarning; 1.0 matches
    # what binary_mean_iou produces for empty masks via its smoothing term.
    if union_count == 0:
        return 1.0
    return np.sum(intersection) / union_count
def acc_val(y_true, y_pred):
    """Return the accuracy of ``y_pred`` against ``y_true``.

    Thin wrapper around ``sklearn.metrics.accuracy_score`` using its
    default options.
    """
    score = accuracy_score(y_true=y_true, y_pred=y_pred)
    return score
def prec_val(y_true, y_pred):
    """Compute macro-averaged and per-class precision.

    Returns:
        A ``(macro, per_class)`` tuple: the macro-averaged precision and the
        per-label precision values, both from
        ``sklearn.metrics.precision_score``.
    """
    macro = precision_score(y_true, y_pred, average='macro')
    per_class = precision_score(y_true, y_pred, average=None)
    return macro, per_class
def recall_val(y_true, y_pred):
    """Compute macro-averaged and per-class recall.

    Returns:
        A ``(macro, per_class)`` tuple: the macro-averaged recall and the
        per-label recall values, both from ``sklearn.metrics.recall_score``.
    """
    macro = recall_score(y_true, y_pred, average='macro')
    per_class = recall_score(y_true, y_pred, average=None)
    return macro, per_class
def find_average(outputs: List, name: str) -> torch.Tensor:
    """Average the tensor stored under ``name`` across a list of outputs.

    Each element of ``outputs`` is indexed with ``name`` (e.g. a dict of
    per-step results). The first entry decides how values are combined:
    0-dim (scalar) tensors are stacked, anything else is concatenated
    along dimension 0. The mean is then taken over every element, so with
    concatenation, entries holding more elements weigh proportionally more.

    Raises:
        IndexError: if ``outputs`` is empty.
    """
    values = [step[name] for step in outputs]
    combine = torch.stack if values[0].ndim == 0 else torch.cat
    return combine(values).mean()
def binary_mean_iou(
    logits: torch.Tensor,
    targets: torch.Tensor
) -> torch.Tensor:
    """Smoothed IoU between thresholded logits and binary targets.

    Logits are binarized with a threshold of 0 (strictly positive -> 1).
    If the binarized mask's shape differs from ``targets``, dimension 1 of
    ``targets`` is squeezed (presumably a singleton channel axis -- confirm
    against callers). Intersection and union are summed over the entire
    tensor, so the score is pooled over all elements rather than averaged
    per sample.

    Returns:
        ``(intersection + EPSILON) / (union + EPSILON)`` as a 0-dim tensor;
        it evaluates to 1 when both masks are empty.
    """
    predicted = (logits > 0).int()
    if targets.shape != predicted.shape:
        targets = torch.squeeze(targets, 1)
    overlap = (targets * predicted).sum()
    total = targets.sum() + predicted.sum() - overlap
    return (overlap + EPSILON) / (total + EPSILON)
# -------------------------------------------------------------------------------
# module metrics Unit Tests
# -------------------------------------------------------------------------------
if __name__ == "__main__":
    # NOTE(review): despite the "Unit Tests" banner above, no tests are run
    # here; executing the module only configures INFO-level logging.
    logging.basicConfig(level=logging.INFO)
# End of module.