Spaces:
Sleeping
Sleeping
import torch | |
def compute_contribution_top_feature(features, outputs, weights, labels):
    """Return the mean share of the predicted-class evidence owned by the top feature.

    For every sample the predicted class is the index of the largest entry
    of ``outputs`` along ``dim=1``.  The sample's features are multiplied
    element-wise by that class' row of ``weights``; the absolute values are
    normalised by their per-sample sum, and the largest resulting fraction
    is kept.  The result is the mean of these fractions over all samples.

    Samples whose contributions sum to zero produce ``0 / 0 = NaN`` and are
    excluded from the mean.

    Args:
        features: Per-sample feature activations.  Assumed to be shaped
            ``(batch, n_features)`` -- TODO confirm against callers.
        outputs: Per-sample class scores; reduced along ``dim=1``.  Assumed
            to be shaped ``(batch, n_classes)`` -- TODO confirm.
        weights: Indexed by predicted class along its first axis and moved
            to ``features.device`` before use.  Assumed to be shaped
            ``(n_classes, n_features)`` -- TODO confirm.
        labels: Unused.  Kept only so existing call sites keep working; the
            returned value never depended on it.

    Returns:
        A 0-dim tensor holding the mean top-feature fraction.  It is NaN
        when no sample remains after the NaN filter (e.g. empty batch).
    """
    with torch.no_grad():
        # torch.max over a dim returns (values, indices); only the indices
        # (the predicted classes) are needed here.
        _, predicted_classes = torch.max(outputs, dim=1)

        # Magnitude of each feature's contribution to the predicted class.
        contributions = (
            features * weights.to(features.device)[predicted_classes]
        ).abs()

        # Fraction of the total absolute contribution carried by each
        # feature; keepdim keeps the sum broadcastable against its rows.
        shares = contributions / contributions.sum(dim=1, keepdim=True)
        top_share = shares.max(dim=1)[0]

        # Rows with an all-zero contribution vector are NaN: drop them so
        # they do not poison the mean.
        top_share = top_share[~torch.isnan(top_share)]
        return top_share.mean()