Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""Functions for computing metrics.""" | |
import torch | |
import numpy as np | |
def topks_correct(preds, labels, ks):
    """
    Count how many samples are correctly classified at each requested top-k.

    A sample is counted for a given k when its label is among the k
    highest-scoring classes in its row of ``preds``.

    Args:
        preds (tensor): prediction scores with dimension batchsize
            N x ClassNum.
        labels (tensor): ground-truth class indices with dimension
            batchsize N.
        ks (list): top-k values to evaluate; e.g. ks = [1, 5] corresponds
            to top-1 and top-5.
    Returns:
        topks_correct (list): one 0-dim float tensor per entry of ``ks``;
            the `i`-th entry is the number of top-`ks[i]` correct
            predictions.
    """
    assert preds.size(0) == labels.size(0), "Batch dim of predictions and labels must match"
    largest_k = max(ks)
    # Class indices of the largest_k best scores per sample, best first:
    # shape (N, largest_k).
    _, ranked_classes = torch.topk(
        preds, largest_k, dim=1, largest=True, sorted=True
    )
    # Transpose to (largest_k, N) so that row r holds every sample's rank-r guess.
    ranked_classes = ranked_classes.t()
    # Broadcast labels from (N,) to (largest_k, N) for an elementwise compare.
    label_grid = labels.view(1, -1).expand_as(ranked_classes)
    # hits[r, j] is True when sample j's rank-r guess equals its label.
    hits = ranked_classes.eq(label_grid)
    # Any hit in the first k rows means the sample is top-k correct.
    return [hits[:k, :].float().sum() for k in ks]
def topk_errors(preds, labels, ks):
    """
    Computes the top-k error for each k.
    Args:
        preds (array): array of predictions. Dimension is batchsize
            N x ClassNum.
        labels (array): array of labels. Dimension is batchsize N.
        ks (list): list of ks to calculate the top-k errors.
    Returns:
        (list): one value per entry of ``ks``; the `i`-th entry is the
            top-`ks[i]` error rate, expressed as a percentage (0-100) of
            the N samples in the batch.
    """
    num_topks_correct = topks_correct(preds, labels, ks)
    batch_size = preds.size(0)
    # Error is the complement of the correct fraction, scaled to percent.
    return [(1.0 - x / batch_size) * 100.0 for x in num_topks_correct]
def topk_accuracies(preds, labels, ks):
    """
    Computes the top-k accuracy for each k.
    Args:
        preds (array): array of predictions. Dimension is batchsize
            N x ClassNum.
        labels (array): array of labels. Dimension is batchsize N.
        ks (list): list of ks to calculate the top-k accuracies.
    Returns:
        (list): one value per entry of ``ks``; the `i`-th entry is the
            top-`ks[i]` accuracy, expressed as a percentage (0-100) of the
            N samples in the batch.
    """
    num_topks_correct = topks_correct(preds, labels, ks)
    batch_size = preds.size(0)
    return [(x / batch_size) * 100.0 for x in num_topks_correct]
def multitask_topks_correct(preds, labels, ks=(1,)):
    """
    Count the samples whose label is within the top-k predictions for
    every task at once.

    Args:
        preds: tuple(torch.FloatTensor), one per task, each of shape
            [batch_size, class_count]. class_count can vary on a per task
            basis, i.e. preds[i].shape[1] can be different to
            preds[j].shape[1].
        labels: tuple(torch.LongTensor), one per task, each of shape
            [batch_size].
        ks: tuple(int), compute the count at top-k for each value of k
            specified in this parameter.
    Returns:
        list(torch.Tensor): same length as ks. The i-th entry is a 0-dim
            float tensor holding the number of samples for which, in every
            task, the label is among the top-ks[i] predictions.
    """
    # zip() would silently drop unmatched tasks while task_count still
    # counts them, so a mismatch would make every count 0.
    assert len(preds) == len(labels), (
        "Number of prediction and label tensors must match"
    )
    max_k = int(max(ks))
    task_count = len(preds)
    batch_size = labels[0].size(0)
    # Accumulate on the predictions' own device (not "GPU whenever CUDA is
    # available", which breaks CPU inputs on CUDA machines). Use int64 so the
    # per-sample task counts cannot wrap around past 255 tasks, as uint8 would.
    all_correct = torch.zeros(
        max_k, batch_size, dtype=torch.long, device=preds[0].device
    )
    for output, label in zip(preds, labels):
        _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True)
        # (batch_size, max_k) -> (max_k, batch_size) so row r holds rank r.
        max_k_idx = max_k_idx.t()
        correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx))
        all_correct.add_(correct_for_task.long())
    # topk indices are distinct within a task, so each task adds at most 1
    # per sample to the first k rows. Reaching task_count therefore means
    # every task ranked its label within the top k.
    multitask_topks_correct = [
        torch.ge(all_correct[:k].float().sum(0), task_count).float().sum(0)
        for k in ks
    ]
    return multitask_topks_correct