import torch


def mic_acc_cal(preds, labels):
    """Compute micro-averaged top-1 accuracy for a batch of predictions.

    Args:
        preds: Tensor of predicted class indices (presumably shape ``(N,)``
            -- TODO confirm against callers).
        labels: Either a tensor of ground-truth indices comparable
            element-wise with ``preds``, or a mixup-style tuple
            ``(targets_a, targets_b, lam)``. In the tuple form, matches
            against ``targets_a`` are weighted by ``lam`` and matches
            against ``targets_b`` by ``1 - lam``.

    Returns:
        Tuple form: ``(lam * hits_a + (1 - lam) * hits_b) / len(preds)``
        as a 0-dim CPU tensor (float32, assuming ``lam`` is a Python number).
        Plain form: ``hits / len(labels)`` as a Python float.
        NOTE(review): the two branches return different types; this is
        preserved as-is for backward compatibility with existing callers.

    Raises:
        ValueError: If ``labels`` is a tuple whose length is not 3.
    """
    if isinstance(labels, tuple):
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if len(labels) != 3:
            raise ValueError(
                "mixup labels must be a (targets_a, targets_b, lam) tuple, "
                f"got {len(labels)} elements"
            )
        targets_a, targets_b, lam = labels
        # Reduce on the tensors' own device and move only the scalar count to
        # the CPU, instead of copying the full comparison mask to the host.
        hits_a = preds.eq(targets_a).sum().float().cpu()
        hits_b = preds.eq(targets_b).sum().float().cpu()
        return (lam * hits_a + (1 - lam) * hits_b) / len(preds)

    return (preds == labels).sum().item() / len(labels)