import torch from torch import nn from architectures.SLDDLevel import SLDDLevel class FinalLayer(): def __init__(self, num_classes, n_features): super().__init__() self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) self.linear = nn.Linear(n_features, num_classes) self.featureDropout = torch.nn.Dropout(0.2) self.selection = None def transform_output(self, feature_maps, with_feature_maps, with_final_features): if self.selection is not None: feature_maps = feature_maps[:, self.selection] x = self.avgpool(feature_maps) pre_out = torch.flatten(x, 1) final_features = self.featureDropout(pre_out) final = self.linear(final_features) final = [final] if with_feature_maps: final.append(feature_maps) if with_final_features: final.append(final_features) if len(final) == 1: final = final[0] return final def set_model_sldd(self, selection, weight, mean, std, bias = None): self.selection = selection self.linear = SLDDLevel(selection, weight, mean, std, bias) self.featureDropout = torch.nn.Dropout(0.1)