pal-b-large-opt-350m / tensor_initializer.py
daiweichen's picture
Upload PAL_B_RM_opt
3a2aa34 verified
import numpy as np
import torch
class TensorInitializer:
    """Seeded factory for randomly initialised ``float32`` tensors.

    An instance is configured once with an initializer name and a seed and
    is then called like a function.  Only the ``'gaussian'`` initializer is
    implemented; calling an instance configured with any other name raises
    ``ValueError``.
    """

    def __init__(self, type: str, seed: int, is_expectation_norm_init: bool = False):
        """Store the configuration and create the random generator.

        Args:
            type: Name of the initializer that ``__call__`` dispatches to.
                The parameter shadows the builtin ``type``; the name is kept
                so that existing keyword callers continue to work.
            seed: Seed passed to ``np.random.default_rng``.
            is_expectation_norm_init: If true, rows are drawn with covariance
                ``I / dim`` (expected squared norm of 1) and returned as
                sampled.  If false, rows are drawn with identity covariance
                and then rescaled to unit Euclidean norm.
        """
        self.initializer_type = type
        self.rng = np.random.default_rng(seed)
        self.is_expectation_norm_init = is_expectation_norm_init

    def gaussian_initializer(
        self,
        dim: int,
        size: int,
    ) -> torch.Tensor:
        """Draw ``size`` zero-mean Gaussian vectors of length ``dim``.

        Args:
            dim: Length of each sampled vector.
            size: Number of vectors to sample.

        Returns:
            A ``float32`` tensor of shape ``(size, dim)``.  Each call advances
            the generator owned by this instance, so repeated calls return
            different samples.
        """
        mean = np.zeros(dim)
        if self.is_expectation_norm_init:
            # Expectation normalization: the per-coordinate variance 1 / dim
            # makes the expected squared norm of every row equal to 1.
            cov = 1 / dim * np.eye(dim)
        else:
            # Enforced normalization: sample isotropically, rescale below.
            cov = np.eye(dim)

        # NOTE(review): the sampling call is left untouched on purpose;
        # switching to a cheaper sampler would presumably change the values
        # produced for a given seed.
        samples = torch.tensor(
            self.rng.multivariate_normal(mean, cov, size),
            dtype=torch.float32,
        )
        if self.is_expectation_norm_init:
            return samples

        # torch.norm is deprecated; torch.linalg.vector_norm is the documented
        # replacement and computes the same per-row 2-norm.
        row_norms = torch.linalg.vector_norm(samples, dim=1, keepdim=True)
        return samples / row_norms

    def __call__(self, *args, **kwargs):
        """Run the configured initializer, forwarding all arguments to it.

        Raises:
            ValueError: If the configured initializer type is not
                ``'gaussian'``.
        """
        if self.initializer_type == 'gaussian':
            return self.gaussian_initializer(*args, **kwargs)
        raise ValueError(f'Unknown initializer type: {self.initializer_type}')