HakimAiV2 / datasets /evaluation /classification_evaluation.py
scdrand23's picture
not working version
814a594
# Copyright (c) Facebook, Inc. and its affiliates.
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------
import torch
import logging
from detectron2.evaluation.evaluator import DatasetEvaluator
from utilities.misc import AverageMeter
from utilities.distributed import get_world_size
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each requested k.

    Args:
        output: Score tensor whose dim 0 is the batch and dim 1 the classes,
            or a list of such tensors, in which case only the last one is
            scored.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate. Values of k larger than the
            number of classes are effectively clamped to the class count.

    Returns:
        A list of Python floats, one per entry of ``topk``, each being the
        percentage of samples whose target is among the k highest-scoring
        classes. For an empty batch every entry is 0.0.
    """
    if isinstance(output, list):
        output = output[-1]

    n_classes = output.size()[1]
    maxk = min(max(topk), n_classes)
    batch_size = target.size(0)

    # An empty batch would otherwise divide by zero below. 0.0 is a neutral
    # value when the caller weights the result by the batch size (0).
    if batch_size == 0:
        return [0.0 for _ in topk]

    # pred has shape (maxk, batch): row i holds the (i+1)-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing past maxk is harmless: it just takes every available row.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res
class ClassificationEvaluator(DatasetEvaluator):
    """Accumulate top-1 / top-5 classification accuracy over a dataset.

    Per-batch accuracies are fed into two ``AverageMeter`` instances; the
    final scores are derived from the meters' ``sum`` and ``count``
    attributes, summed across all processes when running distributed.
    """

    def __init__(self, *args):
        # Positional arguments are accepted and ignored.
        self.top1 = AverageMeter()
        self.top5 = AverageMeter()
        self._logger = logging.getLogger(__name__)

    def reset(self):
        """Clear both accuracy meters."""
        self.top1.reset()
        self.top5.reset()

    def process(self, inputs, outputs):
        """Score one batch and fold it into the running meters.

        Args:
            inputs: Iterable of dicts, each providing a ``'class_id'`` label.
            outputs: Sequence of dicts, each providing a ``'pred_class'``
                tensor of per-class scores for the matching input.
        """
        # torch.stack rejects an empty list; an empty batch contributes
        # nothing to the meters, so skip it.
        if len(outputs) == 0:
            return

        logits = torch.stack([o['pred_class'] for o in outputs])
        y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
        prec1, prec5 = accuracy(logits, y, (1, 5))
        self.top1.update(prec1, y.size(0))
        self.top5.update(prec5, y.size(0))

    def evaluate(self):
        """Return the aggregated accuracy scores.

        Returns:
            ``{'class': {'top1': float, 'top5': float}}``. If no samples were
            processed on any rank, both scores are NaN and a warning is
            logged instead of raising ``ZeroDivisionError``.
        """
        if get_world_size() > 1:
            # Use the current CUDA device when there is one; fall back to
            # CPU so CPU-only multi-process runs do not fail here.
            if torch.cuda.is_available():
                device = torch.device('cuda', torch.cuda.current_device())
            else:
                device = torch.device('cpu')
            # float64 keeps large accumulated sums and counts exact where
            # the default float32 would start rounding them.
            tmp_tensor = torch.tensor(
                [self.top1.sum, self.top5.sum, self.top1.count],
                dtype=torch.float64,
                device=device,
            )
            torch.distributed.all_reduce(
                tmp_tensor, torch.distributed.ReduceOp.SUM
            )
            top1_sum, top5_sum, count = tmp_tensor.tolist()
        else:
            top1_sum = self.top1.sum
            top5_sum = self.top5.sum
            count = self.top1.count

        # The zero check comes after the reduction so that every rank still
        # takes part in the collective call.
        if count == 0:
            self._logger.warning(
                'ClassificationEvaluator.evaluate() called with no processed '
                'samples; reporting NaN accuracy.'
            )
            top1_score = top5_score = float('nan')
        else:
            top1_score = top1_sum / count
            top5_score = top5_sum / count

        results = {}
        scores = {
            'top1': top1_score,
            "top5": top5_score
        }
        results['class'] = scores
        self._logger.info(results)
        return results