|
import numpy as np
import torch
|
|
|
class TensorInitializer:
    """Seeded factory for randomly initialised ``torch`` tensors.

    The instance owns a NumPy ``Generator`` created from ``seed``, so two
    initializers built with the same seed and called with the same arguments
    in the same order produce identical tensors.

    Calling the instance dispatches on ``initializer_type``; ``'gaussian'`` is
    the only type handled here.
    """

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        """Store the configuration and create the random generator.

        Args:
            type: Name of the initializer to dispatch to in ``__call__``
                (``'gaussian'``). The name shadows the builtin but is kept
                because callers may pass it by keyword.
            seed: Seed handed to ``np.random.default_rng``.
            is_expectation_norm_init: When true, rows are drawn with
                covariance ``I / dim`` and left unnormalised; when false,
                rows are drawn with identity covariance and rescaled to unit
                Euclidean norm.
        """
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:
        """Draw ``size`` zero-mean Gaussian vectors of length ``dim``.

        Args:
            dim: Length of each sampled vector.
            size: Number of vectors to sample.

        Returns:
            A ``float32`` tensor with one sampled vector per row. With
            ``is_expectation_norm_init`` set, each row is a raw sample from
            ``N(0, I / dim)``; otherwise each row is a sample from
            ``N(0, I)`` divided by its own Euclidean norm.
        """
        mean = np.zeros(dim)

        if self.is_expectation_norm_init:
            # trace(I / dim) == 1, so the squared norm is 1 in expectation
            # only; individual rows are not rescaled.
            cov = 1 / dim * np.eye(dim)
            samples = self.rng.multivariate_normal(mean, cov, size)
            return torch.tensor(samples, dtype=torch.float32)

        # Isotropic draw followed by a per-row rescale to unit norm.
        cov = np.eye(dim)
        samples = self.rng.multivariate_normal(mean, cov, size)
        unnorm_tensor = torch.tensor(samples, dtype=torch.float32)
        return unnorm_tensor / torch.norm(unnorm_tensor, dim=1, keepdim=True)

    def __call__(self, *args, **kwargs):
        """Run the configured initializer, forwarding all arguments to it.

        Raises:
            ValueError: If ``initializer_type`` is not ``'gaussian'``.
        """
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        raise ValueError(f'Unknown initializer type: {self.initializer_type}')
|
|