Last commit not found
from typing import Dict, Callable | |
import torch | |
from torchmetrics.aggregation import MeanMetric | |
from torchmetrics.classification.accuracy import MulticlassAccuracy | |
from torchmetrics.classification import MulticlassCohenKappa | |
class Metrics:
    """Bundle of torchmetrics accumulators for one data split.

    Tracks the mean loss, overall multiclass accuracy, per-class accuracy and
    Cohen's kappa. Typical use:

    * call :meth:`update` once per batch;
    * call :meth:`log` to compute the accumulated values and forward them to
      ``log_fn``;
    * call :meth:`reset` to clear the accumulated state before the next
      accumulation window (e.g. epoch).

    Args:
        num_classes: Number of target classes.
        labelmap: Mapping from class index to a human-readable name, used in
            the per-class metric keys. Indices missing from the mapping fall
            back to the index itself.
        split: Prefix for every logged key (e.g. ``"train"`` / ``"val"``).
        log_fn: Callback invoked as ``log_fn(key, value, sync_dist=True)``.
            NOTE(review): presumably ``LightningModule.log``; confirm at the
            call site.
    """

    def __init__(self,
                 num_classes: int,
                 labelmap: Dict[int, str],
                 split: str,
                 log_fn: Callable[..., None]) -> None:
        self.labelmap = labelmap
        # NaN losses (if any) are skipped instead of turning the mean into NaN.
        self.loss = MeanMetric(nan_strategy='ignore')
        self.accuracy = MulticlassAccuracy(num_classes=num_classes)
        # average=None -> one accuracy value per class instead of a scalar.
        self.per_class_accuracies = MulticlassAccuracy(
            num_classes=num_classes, average=None)
        self.kappa = MulticlassCohenKappa(num_classes)
        self.split = split
        self.log_fn = log_fn

    def _metrics(self) -> tuple:
        """Return every wrapped metric object, in a fixed order."""
        return (self.loss, self.accuracy, self.per_class_accuracies, self.kappa)

    def update(self,
               loss: torch.Tensor,
               preds: torch.Tensor,
               labels: torch.Tensor) -> None:
        """Accumulate one batch into every metric.

        Args:
            loss: Loss value(s) for the batch, averaged by ``MeanMetric``.
            preds: Predictions in any format the torchmetrics multiclass
                metrics accept. NOTE(review): class indices or per-class
                scores; confirm which the caller passes.
            labels: Ground-truth class indices matching ``preds``.
        """
        self.loss.update(loss)
        self.accuracy.update(preds, labels)
        self.per_class_accuracies.update(preds, labels)
        self.kappa.update(preds, labels)

    def log(self) -> None:
        """Compute the accumulated metrics and pass each one to ``log_fn``.

        Logged keys, each prefixed with ``"<split>/"``: ``loss``,
        ``accuracy``, ``mean_accuracy`` (NaN-ignoring mean of the per-class
        accuracies), ``acc/<index> <name>`` for every class, then ``kappa``.

        The accumulated state is NOT cleared here; call :meth:`reset` when a
        new accumulation window should start.
        """
        loss = self.loss.compute()
        accuracy = self.accuracy.compute()
        per_class = self.per_class_accuracies.compute()
        kappa = self.kappa.compute()
        # nanmean: any class whose accuracy comes back NaN is left out of the
        # average instead of making the whole average NaN.
        mean_accuracy = torch.nanmean(per_class)

        # NOTE(review): torchmetrics' compute() already synchronizes state
        # across processes, so sync_dist=True likely adds a redundant
        # reduction; kept to preserve existing behavior.
        self.log_fn(f"{self.split}/loss", loss, sync_dist=True)
        self.log_fn(f"{self.split}/accuracy", accuracy, sync_dist=True)
        self.log_fn(f"{self.split}/mean_accuracy", mean_accuracy, sync_dist=True)
        for i_class, acc in enumerate(per_class):
            # An incomplete labelmap must not crash logging; fall back to
            # the bare class index as the name.
            try:
                name = self.labelmap[i_class]
            except (KeyError, IndexError):
                name = str(i_class)
            self.log_fn(f"{self.split}/acc/{i_class} {name}", acc, sync_dist=True)
        self.log_fn(f"{self.split}/kappa", kappa, sync_dist=True)

    def reset(self) -> None:
        """Clear the accumulated state of every metric."""
        for metric in self._metrics():
            metric.reset()

    def to(self, device) -> 'Metrics':
        """Move every metric's state to ``device`` and return ``self``.

        ``Metric.to`` (like ``torch.nn.Module.to``) moves the state in place
        and returns the same object, so the result does not need to be
        assigned back.
        """
        for metric in self._metrics():
            metric.to(device)
        return self