File size: 457 Bytes
9b896f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import torch


class NormalizedRepresentation(torch.nn.Module):
    """Standardise inputs with precomputed per-feature statistics.

    Computes ``(X - mean) / std`` where ``mean`` and ``std`` are read from
    ``metadata['X']`` and ``std`` is clamped from below by ``tol`` so the
    division never hits a zero (or near-zero) denominator.
    """

    def __init__(self, loader, metadata, device='cuda', tol: float = 1e-5):
        """Store the normalisation statistics.

        Args:
            loader: Unused; accepted only to keep the constructor signature
                compatible with existing callers.
            metadata: Mapping with ``metadata['X']['mean']`` and
                ``metadata['X']['std']`` tensors.
            device: Preferred device, kept as ``self.device`` for callers
                that read it. ``forward`` follows the device of its input.
            tol: Lower bound applied to ``std``.

        Raises:
            ValueError: If ``metadata`` is ``None``.
        """
        super(NormalizedRepresentation, self).__init__()

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if metadata is None:
            raise ValueError('metadata is required and must not be None')

        stats = metadata['X']
        self.device = device
        self.mu = stats['mean']
        # Clamp from below so constant features do not divide by zero.
        self.sigma = torch.clamp(stats['std'], min=tol)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``(X - mu) / sigma`` computed on the device of ``X``."""
        # Follow the input's device rather than the constructor-time device:
        # this is a no-op when they already match and avoids a device
        # mismatch (or a missing-CUDA failure) when they do not.
        target = X.device
        mu = self.mu.to(target)
        sigma = self.sigma.to(target)
        return (X - mu) / sigma