Haaribo's picture
Add application file
9b896f5
import torch
from torch import nn
from architectures.SLDDLevel import SLDDLevel
class FinalLayer:
    """Classification head: optional channel selection, pooling, dropout, linear.

    NOTE(review): the class declares no base class yet stores ``nn`` modules
    and calls ``super().__init__()``; presumably it is mixed into an
    ``nn.Module`` subclass elsewhere -- confirm against the architectures
    that use it.
    """

    def __init__(self, num_classes, n_features):
        """Create the pooling, linear and dropout layers of the head.

        Args:
            num_classes: Output size of the linear layer.
            n_features: Input size of the linear layer.
        """
        super().__init__()
        # No feature selection until ``set_model_sldd`` installs one.
        self.selection = None
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(n_features, num_classes)
        self.featureDropout = nn.Dropout(0.2)

    def transform_output(self, feature_maps, with_feature_maps,
                         with_final_features):
        """Turn feature maps into the head's output.

        Args:
            feature_maps: Tensor indexed along its second dimension by
                ``self.selection`` (when set) and then pooled to 1x1.
                Assumes a (batch, channels, height, width) layout --
                TODO confirm with callers.
            with_feature_maps: If truthy, also return the (possibly
                selected) feature maps.
            with_final_features: If truthy, also return the pooled,
                flattened features after dropout.

        Returns:
            The linear layer's output alone when neither flag is set;
            otherwise a list starting with that output, followed by the
            feature maps and/or the final features, in that order.
        """
        if self.selection is not None:
            feature_maps = feature_maps[:, self.selection]

        pooled = torch.flatten(self.avgpool(feature_maps), 1)
        final_features = self.featureDropout(pooled)
        logits = self.linear(final_features)

        extras = []
        if with_feature_maps:
            extras.append(feature_maps)
        if with_final_features:
            extras.append(final_features)

        if not extras:
            return logits
        return [logits, *extras]

    def set_model_sldd(self, selection, weight, mean, std, bias=None):
        """Install a feature selection and swap in an ``SLDDLevel`` layer.

        Stores ``selection`` for use in ``transform_output``, replaces the
        linear layer with ``SLDDLevel(selection, weight, mean, std, bias)``
        and replaces the feature dropout with one of probability 0.1.
        """
        self.selection = selection
        self.linear = SLDDLevel(selection, weight, mean, std, bias)
        self.featureDropout = nn.Dropout(0.1)