File size: 1,210 Bytes
3a2aa34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import torch

class TensorInitializer:
    """Seeded, callable factory for randomly initialised float32 tensors.

    Calling the instance dispatches to the initializer chosen at construction,
    e.g. ``TensorInitializer('gaussian', seed=0)(dim, size)``. Only
    ``'gaussian'`` is supported; any other type raises ``ValueError`` when the
    instance is *called* (construction itself does not validate the type).

    Args:
        type: Name of the initializer to dispatch to (``'gaussian'``). The name
            shadows the builtin but is kept so keyword callers keep working.
        seed: Seed for the instance's private ``numpy.random.Generator``.
            Repeated calls on one instance continue the same random stream.
        is_expectation_norm_init: If True, Gaussian rows are scaled so their
            squared L2 norm equals 1 in expectation. If False (default), every
            row is explicitly normalised to unit L2 norm.
    """

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:
        """Draw ``size`` isotropic zero-mean Gaussian vectors of length ``dim``.

        Args:
            dim: Length of each vector (the last axis of the result).
            size: Number of vectors. An int yields shape ``(size, dim)``; a
                tuple yields ``(*size, dim)``, following the
                ``Generator.multivariate_normal`` shape convention.

        Returns:
            A float32 tensor of shape ``(size, dim)`` (or ``(*size, dim)``).
            With ``is_expectation_norm_init`` each entry is N(0, 1/dim), so the
            expected squared norm of a row is 1. Otherwise each vector along
            the last axis has unit L2 norm.
        """
        if size is None:
            shape = (dim,)
        elif isinstance(size, (int, np.integer)):
            shape = (size, dim)
        else:
            shape = (*size, dim)

        # The covariance is always a scalar multiple of the identity. Sampling
        # standard normals and scaling is therefore the same distribution as
        # multivariate_normal(0, c * I), without allocating a dim x dim matrix
        # or running an O(dim^3) SVD. NumPy's multivariate_normal draws its
        # standard normals of this same shape from the generator first, so
        # seeded outputs should match the previous implementation.
        # NOTE(review): exact bit-equality relies on the old SVD of I being I.
        samples = self.rng.standard_normal(shape)

        if self.is_expectation_norm_init:
            # Expectation normalisation: variance 1/dim per coordinate.
            return torch.tensor(samples * np.sqrt(1 / dim), dtype=torch.float32)

        # Enforced normalisation: cast first, then divide each vector (last
        # axis) by its own L2 norm so every row has unit length.
        unnorm_tensor = torch.tensor(samples, dtype=torch.float32)
        return unnorm_tensor / torch.norm(unnorm_tensor, dim=-1, keepdim=True)

    def __call__(self, *args, **kwargs):
        """Forward to the initializer selected by ``initializer_type``.

        Raises:
            ValueError: If ``initializer_type`` is not ``'gaussian'``.
        """
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        raise ValueError(f'Unknown initializer type: {self.initializer_type}')