File size: 1,210 Bytes
3a2aa34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import numpy as np
import torch
class TensorInitializer:
    """Callable factory for randomly initialised ``torch.float32`` tensors.

    The only initializer recognised by ``__call__`` is ``'gaussian'``: rows are
    drawn from a zero-mean isotropic normal distribution and normalised in one
    of two ways (see :meth:`gaussian_initializer`). All sampling goes through a
    private NumPy ``Generator`` seeded with ``seed``, so the sequence of tensors
    produced by one instance is reproducible.
    """

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        """
        Args:
            type: Name of the initializer; ``'gaussian'`` is the only value
                ``__call__`` accepts. (The name shadows the builtin ``type`` but
                is kept so keyword callers keep working.)
            seed: Seed passed to ``np.random.default_rng``.
            is_expectation_norm_init: If True, sample with covariance ``I / dim``
                so each row has unit squared L2 norm *in expectation*; if False,
                sample with covariance ``I`` and divide every row by its own
                L2 norm.
        """
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:
        """Sample ``size`` Gaussian vectors of length ``dim``.

        Args:
            dim: Length of each vector (last axis of the result); must be >= 1.
            size: Number of vectors, interpreted the way NumPy's ``size``
                argument is: an int gives shape ``(size, dim)``, a tuple gives
                ``(*size, dim)``, and None gives a single ``(dim,)`` vector.

        Returns:
            A ``torch.float32`` tensor whose last axis has length ``dim``.

        Raises:
            ValueError: If ``dim`` is less than 1.
        """
        if dim < 1:
            raise ValueError(f'dim must be a positive integer, got {dim}')

        if size is None:
            shape = (dim,)
        else:
            # Mirror NumPy's `size` handling: int -> (size, dim), tuple -> (*size, dim).
            shape = (*np.atleast_1d(size).tolist(), dim)

        # A zero-mean normal with covariance c * I has i.i.d. N(0, c) entries, so
        # draw standard normals directly instead of calling multivariate_normal,
        # which would build a dim x dim covariance and SVD-factor it
        # (O(dim^2) memory, O(dim^3) time). multivariate_normal draws from the same
        # standard_normal stream, so seeded results match up to float rounding.
        samples = self.rng.standard_normal(shape)

        if self.is_expectation_norm_init:
            # Expectation normalisation: covariance I / dim gives E[||x||^2] = 1.
            return torch.tensor(samples * np.sqrt(1.0 / dim), dtype=torch.float32)

        # Enforced normalisation: every vector scaled to exactly unit L2 norm.
        # Normalise along the last axis so tuple-valued `size` works too.
        unnorm_tensor = torch.tensor(samples, dtype=torch.float32)
        return unnorm_tensor / torch.norm(unnorm_tensor, dim=-1, keepdim=True)

    def __call__(self, *args, **kwargs):
        """Dispatch to the sampler selected by ``initializer_type``.

        Arguments are forwarded unchanged to the selected sampler.

        Raises:
            ValueError: If ``initializer_type`` is not ``'gaussian'``.
        """
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        raise ValueError(f'Unknown initializer type: {self.initializer_type}')
|