import torch


def compute_contribution_top_feature(features, outputs, weights, labels):
    """Return the mean share of the predicted-class contribution held by the top feature.

    For every sample, each feature value is multiplied by the weight that the
    predicted class (argmax of ``outputs`` along dim 1) assigns to it. The
    absolute contributions of a sample are normalised by their sum, the
    largest fraction is taken per sample, and those maxima are averaged over
    the batch.

    Args:
        features: 2-D tensor indexed as (sample, feature).
            Presumably the inputs to the final linear layer - confirm with caller.
        outputs: 2-D tensor indexed as (sample, class); only its argmax along
            dim 1 is used.
        weights: 2-D tensor indexed as (class, feature); it is moved to
            ``features.device`` before use.
        labels: Unused. Kept only so that existing callers keep working.

    Returns:
        A 0-dim tensor holding the mean top-feature fraction. Samples whose
        absolute contributions sum to zero are ignored; if every sample is
        ignored (or the batch is empty) the result is NaN.
    """
    with torch.no_grad():
        # Only the winning class index is needed, not its logit value.
        _, predicted_classes = torch.max(outputs, dim=1)

        # Per-sample contribution of each feature to the predicted class.
        predicted_weights = weights.to(features.device)[predicted_classes]
        abs_contributions = (features * predicted_weights).abs()

        # Normalise each row so it sums to one; rows summing to zero yield
        # 0 / 0 = NaN, which is filtered out below.
        row_totals = abs_contributions.sum(dim=1, keepdim=True)
        fractions = abs_contributions / row_totals

        top_fraction = fractions.max(dim=1)[0]
        top_fraction = top_fraction[~torch.isnan(top_fraction)]
        return top_fraction.mean()