import torch
import torch.nn


class SLDDLevel(torch.nn.Module):
    """Final linear level that standardizes a selected subset of features and
    applies a fixed (non-trainable) weight matrix."""

    def __init__(self, selection, weight_at_selection, mean, std, bias=None):
        super().__init__()
        # Indices of the selected features; stored as a buffer so they follow
        # the module across devices without being trained.
        self.register_buffer('selection', torch.tensor(selection, dtype=torch.long))
        num_classes, n_features = weight_at_selection.shape

        # If the statistics cover all features, reduce them to the selected subset.
        selected_mean = mean
        selected_std = std
        if len(selected_mean) != len(selection):
            selected_mean = selected_mean[selection]
            selected_std = selected_std[selection]
        self.mean = torch.nn.Parameter(selected_mean)
        self.std = torch.nn.Parameter(selected_std)

        # The linear layer carries the frozen weight matrix; an optional bias
        # is frozen as well.
        if bias is not None:
            self.layer = torch.nn.Linear(n_features, num_classes)
            self.layer.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.layer = torch.nn.Linear(n_features, num_classes, bias=False)
        self.layer.weight = torch.nn.Parameter(weight_at_selection, requires_grad=False)

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        # Expose a zero bias (on the layer's device) when the layer has none.
        if self.layer.bias is None:
            return torch.zeros(self.layer.out_features, device=self.layer.weight.device)
        else:
            return self.layer.bias

    def forward(self, input):
        # Standardize the selected features; the clamp guards against division
        # by (near-)zero standard deviations.
        input = (input - self.mean) / torch.clamp(self.std, min=1e-6)
        return self.layer(input)
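

# Illustrative usage sketch (not part of the original module): builds an
# SLDDLevel with made-up sizes and toy values to show the expected shapes.
# All numbers below (feature/class counts, selection indices) are assumptions.
if __name__ == "__main__":
    n_total_features, n_selected, n_classes = 512, 50, 10
    selection = torch.randperm(n_total_features)[:n_selected].tolist()  # hypothetical selected feature indices
    weight = torch.zeros(n_classes, n_selected)
    weight[torch.arange(n_classes), torch.arange(n_classes)] = 1.0  # toy frozen weights
    mean = torch.zeros(n_total_features)  # full-length statistics; __init__ reduces them via `selection`
    std = torch.ones(n_total_features)
    level = SLDDLevel(selection, weight, mean, std)
    # The layer expects inputs that already contain only the selected features.
    features = torch.randn(4, n_selected)
    logits = level(features)
    print(logits.shape)  # torch.Size([4, 10])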