|
import numpy as np |
|
import torch |
|
|
|
from evaluation.helpers import softmax_feature_maps |
|
|
|
|
|
class MultiKCrossChannelMaxPooledSum: |
|
def __init__(self, top_k_range, weights, interactions, func="softmax"): |
|
self.top_k_range = top_k_range |
|
self.weights = weights |
|
self.failed = False |
|
self.max_ks = self.get_max_ks(weights) |
|
self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device) |
|
self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device) |
|
self.ns_k = torch.zeros(len(top_k_range), device=weights.device) |
|
self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device) |
|
self.interactions = interactions |
|
self.func = func |
|
|
|
def get_max_ks(self, weights): |
|
nonzeros = torch.count_nonzero(torch.tensor(weights), 1) |
|
return nonzeros |
|
|
|
def get_top_n_locality(self, outputs, initial_feature_maps, k): |
|
feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, |
|
initial_feature_maps) |
|
max_ks = self.max_ks[top_classes] |
|
max_k_based_row_selection = max_ks >= k |
|
|
|
result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps, |
|
separated=True) |
|
return result |
|
|
|
def get_locality(self, outputs, initial_feature_maps, n): |
|
answer = self.get_top_n_locality(outputs, initial_feature_maps, n) |
|
return answer |
|
|
|
def get_result(self): |
|
|
|
|
|
|
|
|
|
exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features) |
|
local_array = torch.zeros_like(self.locality_of_used_features) |
|
|
|
|
|
cumulated = torch.cumsum(self.exclusive_ns, 0) |
|
end_idx = torch.argmax(cumulated) |
|
exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1] |
|
exclusivity_array[exclusivity_array != exclusivity_array] = 0 |
|
exclusive_array[:len(exclusivity_array)] = exclusivity_array |
|
locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[ |
|
self.locality_of_used_features != 0] |
|
local_array[:len(locality_array)] = locality_array |
|
return local_array, exclusive_array |
|
|
|
def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False): |
|
relevant_indices = get_relevant_indices(relevant_weights, k)[mask] |
|
|
|
indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size) |
|
sub_feature_maps = torch.gather(feature_maps[mask], 1, indices) |
|
|
|
cross_pooled = torch.max(sub_feature_maps, 1)[0] |
|
if separated: |
|
return torch.sum(cross_pooled, 1) / k |
|
else: |
|
ns = len(cross_pooled) |
|
result = torch.sum(cross_pooled) / (k) |
|
|
|
|
|
return ns, result |
|
|
|
def adapt_feature_maps(self, outputs, initial_feature_maps): |
|
if self.func == "softmax": |
|
feature_maps = softmax_feature_maps(initial_feature_maps) |
|
feature_maps = torch.flatten(feature_maps, 2) |
|
vector_size = feature_maps.shape[2] |
|
top_classes = torch.argmax(outputs, dim=1) |
|
relevant_weights = self.weights[top_classes] |
|
if relevant_weights.shape[1] != feature_maps.shape[1]: |
|
feature_maps = self.interactions.get_localized_features(initial_feature_maps) |
|
feature_maps = softmax_feature_maps(feature_maps) |
|
feature_maps = torch.flatten(feature_maps, 2) |
|
return feature_maps, relevant_weights, vector_size, top_classes |
|
|
|
def calculate_locality(self, outputs, initial_feature_maps): |
|
feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, |
|
initial_feature_maps) |
|
max_ks = self.max_ks[top_classes] |
|
for k in self.top_k_range: |
|
|
|
max_k_based_row_selection = max_ks >= k |
|
if torch.sum(max_k_based_row_selection) == 0: |
|
break |
|
|
|
exclusive_k = max_ks == k |
|
if torch.sum(exclusive_k) != 0: |
|
ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps) |
|
self.locality_of_exclusely_used_features[k - 1] += result |
|
self.exclusive_ns[k - 1] += ns |
|
ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps) |
|
self.ns_k[k - 1] += ns |
|
self.locality_of_used_features[k - 1] += result |
|
|
|
def __call__(self, outputs, initial_feature_maps): |
|
self.calculate_locality(outputs, initial_feature_maps) |
|
|
|
|
|
def get_relevant_indices(weights, top_k): |
|
top_k = weights.topk(top_k)[1] |
|
return top_k |