import torch
from torch import nn

from architectures.SLDDLevel import SLDDLevel


class FinalLayer:
    """Classification head: pool feature maps, apply dropout, then a linear map.

    The class does not derive from ``nn.Module`` itself, yet it stores
    modules as attributes and calls ``super().__init__()``.
    NOTE(review): presumably meant to be mixed into an ``nn.Module``
    subclass so the layers get registered -- confirm against the callers.
    """

    def __init__(self, num_classes, n_features):
        """Create the pooling, dropout and dense linear layers.

        Args:
            num_classes: Output size of the linear layer.
            n_features: Input size of the linear layer.
        """
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(n_features, num_classes)
        self.featureDropout = torch.nn.Dropout(0.2)
        # Optional index applied along dim 1 of the feature maps in
        # ``transform_output``; ``None`` means every channel is kept.
        self.selection = None

    def transform_output(self, feature_maps, with_feature_maps,
                         with_final_features):
        """Turn feature maps into the output of the final layer.

        Args:
            feature_maps: Tensor that is indexed along dim 1 by
                ``self.selection`` (when set) and then average-pooled to
                a 1x1 spatial size.
            with_feature_maps: If true, also return the (possibly
                selected) feature maps.
            with_final_features: If true, also return the pooled,
                flattened features *after* dropout was applied.

        Returns:
            The output of ``self.linear`` alone when neither flag is set;
            otherwise a list starting with that output, followed by the
            feature maps and/or the final features, in that order.
        """
        if self.selection is not None:
            feature_maps = feature_maps[:, self.selection]

        pooled = self.avgpool(feature_maps)
        pre_out = torch.flatten(pooled, 1)
        final_features = self.featureDropout(pre_out)

        final = [self.linear(final_features)]
        if with_feature_maps:
            final.append(feature_maps)
        if with_final_features:
            final.append(final_features)

        # A single result is returned bare rather than as a one-item list.
        if len(final) == 1:
            final = final[0]
        return final

    def set_model_sldd(self, selection, weight, mean, std, bias=None):
        """Swap the dense head for an ``SLDDLevel`` and store the selection.

        All arguments are forwarded unchanged to ``SLDDLevel``; their exact
        meaning is defined there. ``selection`` is additionally stored and
        used by ``transform_output`` to index the feature maps along dim 1.
        The feature dropout is replaced by one with probability 0.1.
        """
        self.selection = selection
        self.linear = SLDDLevel(selection, weight, mean, std, bias)
        # NOTE(review): a freshly constructed nn.Dropout starts in training
        # mode; if this is called on a model already switched to eval, the
        # caller has to re-apply ``.eval()`` for dropout to be disabled.
        self.featureDropout = torch.nn.Dropout(0.1)