import torch


def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of predictions that match ``labels``.

    The predicted class for each row is the index of the largest value
    along dimension 1 of ``outputs``.

    Args:
        outputs: Score tensor reduced along dim 1; presumably shaped
            (batch, num_classes) -- confirm against callers.
        labels: Target class indices compared element-wise with the
            predictions; presumably shaped (batch,) -- confirm against
            callers.

    Returns:
        Number of matching predictions divided by ``len(labels)`` as a
        Python float, or 0.0 when ``labels`` is empty.
    """
    total = len(labels)
    # An empty batch (e.g. a trailing empty split) would otherwise raise
    # ZeroDivisionError; report zero accuracy instead.
    if total == 0:
        return 0.0

    # Only the indices are needed, so argmax replaces the (values, indices)
    # pair that torch.max(outputs, 1) returns.
    preds = torch.argmax(outputs, dim=1)
    correct = torch.sum(preds == labels).item()
    return correct / total