import numpy as np import torch from evaluation.helpers import softmax_feature_maps class MultiKCrossChannelMaxPooledSum: def __init__(self, top_k_range, weights, interactions, func="softmax"): self.top_k_range = top_k_range self.weights = weights self.failed = False self.max_ks = self.get_max_ks(weights) self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device) self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device) self.ns_k = torch.zeros(len(top_k_range), device=weights.device) self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device) self.interactions = interactions self.func = func def get_max_ks(self, weights): nonzeros = torch.count_nonzero(torch.tensor(weights), 1) return nonzeros def get_top_n_locality(self, outputs, initial_feature_maps, k): feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, initial_feature_maps) max_ks = self.max_ks[top_classes] max_k_based_row_selection = max_ks >= k result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps, separated=True) return result def get_locality(self, outputs, initial_feature_maps, n): answer = self.get_top_n_locality(outputs, initial_feature_maps, n) return answer def get_result(self): # if torch.sum(self.exclusive_ns) ==0: # end_idx = len(self.exclusive_ns) - 1 # else: exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features) local_array = torch.zeros_like(self.locality_of_used_features) # if self.failed: # return local_array, exclusive_array cumulated = torch.cumsum(self.exclusive_ns, 0) end_idx = torch.argmax(cumulated) exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1] exclusivity_array[exclusivity_array != exclusivity_array] = 0 exclusive_array[:len(exclusivity_array)] = exclusivity_array locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[ self.locality_of_used_features != 0] local_array[:len(locality_array)] = locality_array return local_array, exclusive_array def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False): relevant_indices = get_relevant_indices(relevant_weights, k)[mask] # this should have size batch x k x featuremapsize squared] indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size) sub_feature_maps = torch.gather(feature_maps[mask], 1, indices) # shape batch x featuremapsquared: For each "pixel" the highest value cross_pooled = torch.max(sub_feature_maps, 1)[0] if separated: return torch.sum(cross_pooled, 1) / k else: ns = len(cross_pooled) result = torch.sum(cross_pooled) / (k) # should be batch x map size return ns, result def adapt_feature_maps(self, outputs, initial_feature_maps): if self.func == "softmax": feature_maps = softmax_feature_maps(initial_feature_maps) feature_maps = torch.flatten(feature_maps, 2) vector_size = feature_maps.shape[2] top_classes = torch.argmax(outputs, dim=1) relevant_weights = self.weights[top_classes] if relevant_weights.shape[1] != feature_maps.shape[1]: feature_maps = self.interactions.get_localized_features(initial_feature_maps) feature_maps = softmax_feature_maps(feature_maps) feature_maps = torch.flatten(feature_maps, 2) return feature_maps, relevant_weights, vector_size, top_classes def calculate_locality(self, outputs, initial_feature_maps): feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, initial_feature_maps) max_ks = self.max_ks[top_classes] for k in self.top_k_range: # relevant_k_s = max_ks[] max_k_based_row_selection = max_ks >= k if torch.sum(max_k_based_row_selection) == 0: break exclusive_k = max_ks == k if torch.sum(exclusive_k) != 0: ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps) self.locality_of_exclusely_used_features[k - 1] += result self.exclusive_ns[k - 1] += ns ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps) self.ns_k[k - 1] += ns self.locality_of_used_features[k - 1] += result def __call__(self, outputs, initial_feature_maps): self.calculate_locality(outputs, initial_feature_maps) def get_relevant_indices(weights, top_k): top_k = weights.topk(top_k)[1] return top_k