File size: 1,236 Bytes
9b896f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn

from architectures.SLDDLevel import SLDDLevel


class FinalLayer():
    """Classification head: global average pooling -> dropout -> linear.

    This class does not subclass ``nn.Module`` itself. It presumably gets
    combined with one via multiple inheritance, so that the layers assigned
    here are registered as submodules of the host model.
    NOTE(review): confirm against the architectures that use it.
    """

    def __init__(self, num_classes, n_features, dropout=0.2):
        """Create the dense classification head.

        Args:
            num_classes: Number of output logits produced by ``linear``.
            n_features: Number of input channels expected by ``linear``.
            dropout: Dropout probability applied to the pooled features.
                It defaults to 0.2, the previously hard-coded value.
        """
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(n_features, num_classes)
        self.featureDropout = torch.nn.Dropout(dropout)
        # Channel indices to keep before pooling; None keeps all channels.
        self.selection = None

    def transform_output(self, feature_maps, with_feature_maps,
                         with_final_features):
        """Turn backbone feature maps into logits (plus optional extras).

        ``feature_maps`` is presumably shaped (batch, channels, H, W).
        TODO confirm with the calling backbones. If ``self.selection`` is
        set, only those channels (dim 1) are kept.

        Args:
            feature_maps: Backbone output tensor.
            with_feature_maps: If truthy, also return the (selected) feature
                maps.
            with_final_features: If truthy, also return the pooled features
                after dropout, i.e. the input of ``linear``.

        Returns:
            The logits alone when neither flag is set. Otherwise, a list
            ``[logits, feature_maps?, final_features?]`` in that order.
        """
        if self.selection is not None:
            feature_maps = feature_maps[:, self.selection]
        pooled = torch.flatten(self.avgpool(feature_maps), 1)
        final_features = self.featureDropout(pooled)
        logits = self.linear(final_features)
        if not (with_feature_maps or with_final_features):
            return logits
        outputs = [logits]
        if with_feature_maps:
            outputs.append(feature_maps)
        if with_final_features:
            outputs.append(final_features)
        return outputs

    def set_model_sldd(self, selection, weight, mean, std, bias=None,
                       dropout=0.1):
        """Switch the head to an ``SLDDLevel`` layer over selected channels.

        Args:
            selection: Channel indices kept by ``transform_output``. They
                are also passed on to ``SLDDLevel``.
            weight, mean, std, bias: Forwarded unchanged to ``SLDDLevel``.
            dropout: Dropout probability for the new head. It defaults to
                0.1, the previously hard-coded value.
        """
        self.selection = selection
        self.linear = SLDDLevel(selection, weight, mean, std, bias)
        self.featureDropout = torch.nn.Dropout(dropout)
        # Freshly constructed nn.Modules always start in training mode.
        # When the host model has a ``training`` flag (i.e. it is an
        # nn.Module), align the new layers with it. Otherwise, swapping the
        # head after ``model.eval()`` would silently re-enable dropout.
        training = getattr(self, "training", None)
        if training is not None:
            for module in (self.linear, self.featureDropout):
                if isinstance(module, nn.Module):
                    module.train(training)