"""
Feature Diversity Loss

Usage, to replicate the paper: call

    loss_function = FeatureDiversityLoss(0.196, linear)

to initialize the loss with the final linear layer of the model. At each
mini-batch, get the feature maps (the output of the final convolutional
layer) and add the diversity term to the loss:

    loss += loss_function(feature_maps, outputs)

A runnable sketch of this usage is given in the __main__ block at the
bottom of this file.
"""
import torch
from torch import nn
|
|
class FeatureDiversityLoss(nn.Module):
    """Encourages different feature maps to be active at different spatial locations."""

    def __init__(self, scaling_factor, linear):
        super().__init__()
        self.scaling_factor = scaling_factor
        print("Scaling Factor: ", self.scaling_factor)
        self.linearLayer = linear

    def initialize(self, linearLayer):
        # Allows swapping in the final linear layer after construction.
        self.linearLayer = linearLayer

    def get_weights(self, outputs):
        # For each sample, pick the classifier weights of its top-scoring
        # class; only the magnitudes matter, hence the absolute value.
        weight_matrix = torch.abs(self.linearLayer.weight)
        top_classes = torch.argmax(outputs, dim=1)
        relevant_weights = weight_matrix[top_classes]
        return relevant_weights

    def forward(self, feature_maps, outputs):
        relevant_weights = self.get_weights(outputs)
        relevant_weights = norm_vector(relevant_weights)
        feature_maps = preserve_avg_func(feature_maps)
        flattened_feature_maps = feature_maps.flatten(2)  # (batch, channels, H*W)
        batch = flattened_feature_maps.size(0)
        # Weight each channel by its relevance to the predicted class, then
        # take the strongest channel response at every spatial position.
        relevant_feature_maps = flattened_feature_maps * relevant_weights[..., None]
        diversity_loss = torch.sum(torch.amax(relevant_feature_maps, dim=1))
        # Negated: the diversity term is maximized, so it lowers the total loss.
        return -diversity_loss / batch * self.scaling_factor
|
|
def norm_vector(x):
    # L2-normalize each row; the epsilon guards against division by zero.
    return x / (torch.norm(x, dim=1) + 1e-5)[:, None]
|
|
def preserve_avg_func(x):
    # Softmax each map over its spatial positions, then rescale so every
    # channel keeps its average activation relative to the sample's
    # strongest channel.
    avgs = torch.mean(x, dim=[2, 3])
    max_avgs = torch.max(avgs, dim=1)[0]
    scaling_factor = avgs / torch.clamp(max_avgs[..., None], min=1e-6)
    softmaxed_maps = softmax_feature_maps(x)
    scaled_maps = softmaxed_maps * scaling_factor[..., None, None]
    return scaled_maps
|
|
def softmax_feature_maps(x):
    # Softmax over the flattened spatial dimensions of each feature map.
    return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x)
|
|
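# A minimal usage sketch, assuming a generic CNN whose final convolutional
# output is globally average-pooled and fed to a linear classifier. The toy
# backbone, layer sizes, and random data below are illustrative stand-ins,
# not the paper's setup; only the 0.196 factor comes from the docstring above.
if __name__ == "__main__":
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)  # stand-in backbone
    linear = nn.Linear(16, 10)                         # final linear layer

    loss_function = FeatureDiversityLoss(0.196, linear)
    criterion = nn.CrossEntropyLoss()

    images = torch.randn(4, 3, 8, 8)
    targets = torch.randint(0, 10, (4,))

    feature_maps = conv(images)             # output of the final conv layer
    pooled = feature_maps.mean(dim=[2, 3])  # global average pooling
    outputs = linear(pooled)

    loss = criterion(outputs, targets)
    loss += loss_function(feature_maps, outputs)  # add the diversity term
    loss.backward()
    print("total loss:", loss.item())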