Haaribo's picture
Add application file
9b896f5
raw
history blame contribute delete
457 Bytes
import torch
class NormalizedRepresentation(torch.nn.Module):
    """Standardize inputs with precomputed statistics.

    Computes ``(X - mean) / max(std, tol)`` using the ``mean`` and ``std``
    values stored under ``metadata['X']``.

    Args:
        loader: Accepted for interface compatibility; not used by this class.
        metadata: Mapping providing ``metadata['X']['mean']`` and
            ``metadata['X']['std']``. Tensors are used as-is; other
            array-likes (lists, NumPy arrays, scalars) are converted with
            ``torch.as_tensor``.
        device: Device the statistics are moved to before use
            (default ``'cuda'``).
        tol: Lower bound applied to ``std`` so the division never uses a
            zero or near-zero denominator.

    Raises:
        ValueError: If ``metadata`` is ``None``.
    """

    def __init__(self, loader, metadata, device='cuda', tol=1e-5):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would surface later as an opaque TypeError.
        if metadata is None:
            raise ValueError("NormalizedRepresentation requires metadata, got None")
        self.device = device
        # ``as_tensor`` returns an existing tensor unchanged, so tensor inputs
        # behave exactly as before while array-likes are also accepted.
        self.mu = torch.as_tensor(metadata['X']['mean'])
        self.sigma = torch.clamp(torch.as_tensor(metadata['X']['std']), tol)

    def forward(self, X):
        """Standardize ``X`` as ``(X - mu) / sigma``.

        The statistics are moved to ``self.device`` first. ``X`` is expected
        to already live on that device, since PyTorch generally rejects
        arithmetic between tensors on different devices.
        """
        # ``Tensor.to`` returns the tensor itself when it is already on the
        # target device, so re-binding caches the moved copy. Only the first
        # call (or a call after ``self.device`` changes) pays for a transfer,
        # instead of every call.
        self.mu = self.mu.to(self.device)
        self.sigma = self.sigma.to(self.device)
        return (X - self.mu) / self.sigma