dartpain committed
Commit 52794a7 · 1 Parent(s): 6c6979d

Delete param_init_fns.py

Files changed (1)
  1. param_init_fns.py +0 -181
param_init_fns.py DELETED
@@ -1,181 +0,0 @@
-import math
-import warnings
-from collections.abc import Sequence
-from functools import partial
-from typing import Optional, Tuple, Union
-import torch
-from torch import nn
-from .norm import NORM_CLASS_REGISTRY
-
-def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
-    del kwargs
-    if verbose > 1:
-        warnings.warn(f"Initializing network using module's reset_parameters attribute")
-    if hasattr(module, 'reset_parameters'):
-        module.reset_parameters()
-
-def fused_init_helper_(module: nn.Module, init_fn_):
-    _fused = getattr(module, '_fused', None)
-    if _fused is None:
-        raise RuntimeError(f'Internal logic error')
-    (dim, splits) = _fused
-    splits = (0, *splits, module.weight.size(dim))
-    for (s, e) in zip(splits[:-1], splits[1:]):
-        slice_indices = [slice(None)] * module.weight.ndim
-        slice_indices[dim] = slice(s, e)
-        init_fn_(module.weight[slice_indices])
-
-def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
-    del kwargs
-    if verbose > 1:
-        warnings.warn(f'If model has bias parameters they are initialized to 0.')
-    init_div_is_residual = init_div_is_residual
-    if init_div_is_residual is False:
-        div_is_residual = 1.0
-    elif init_div_is_residual is True:
-        div_is_residual = math.sqrt(2 * n_layers)
-    elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
-        div_is_residual = init_div_is_residual
-    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
-        div_is_residual = float(init_div_is_residual)
-    else:
-        div_is_residual = 1.0
-        raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
-    if init_div_is_residual is not False:
-        if verbose > 1:
-            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
-    if isinstance(module, nn.Linear):
-        if hasattr(module, '_fused'):
-            fused_init_helper_(module, init_fn_)
-        else:
-            init_fn_(module.weight)
-        if module.bias is not None:
-            torch.nn.init.zeros_(module.bias)
-        if init_div_is_residual is not False and getattr(module, '_is_residual', False):
-            with torch.no_grad():
-                module.weight.div_(div_is_residual)
-    elif isinstance(module, nn.Embedding):
-        if emb_init_std is not None:
-            std = emb_init_std
-            if std == 0:
-                warnings.warn(f'Embedding layer initialized to 0.')
-            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
-        elif emb_init_uniform_lim is not None:
-            lim = emb_init_uniform_lim
-            if isinstance(lim, Sequence):
-                if len(lim) > 2:
-                    raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
-                if lim[0] == lim[1]:
-                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
-            else:
-                if lim == 0:
-                    warnings.warn(f'Embedding layer initialized to 0.')
-                lim = [-lim, lim]
-            (a, b) = lim
-            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
-        else:
-            emb_init_fn_ = init_fn_
-        emb_init_fn_(module.weight)
-    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
-        if verbose > 1:
-            warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
-        if hasattr(module, 'weight') and module.weight is not None:
-            torch.nn.init.ones_(module.weight)
-        if hasattr(module, 'bias') and module.bias is not None:
-            torch.nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.MultiheadAttention):
-        if module._qkv_same_embed_dim:
-            assert module.in_proj_weight is not None
-            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
-            assert d_model is not None
-            _d = d_model
-            splits = (0, _d, 2 * _d, 3 * _d)
-            for (s, e) in zip(splits[:-1], splits[1:]):
-                init_fn_(module.in_proj_weight[s:e])
-        else:
-            assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
-            assert module.in_proj_weight is None
-            init_fn_(module.q_proj_weight)
-            init_fn_(module.k_proj_weight)
-            init_fn_(module.v_proj_weight)
-        if module.in_proj_bias is not None:
-            torch.nn.init.zeros_(module.in_proj_bias)
-        if module.bias_k is not None:
-            torch.nn.init.zeros_(module.bias_k)
-        if module.bias_v is not None:
-            torch.nn.init.zeros_(module.bias_v)
-        init_fn_(module.out_proj.weight)
-        if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
-            with torch.no_grad():
-                module.out_proj.weight.div_(div_is_residual)
-        if module.out_proj.bias is not None:
-            torch.nn.init.zeros_(module.out_proj.bias)
-    else:
-        for _ in module.parameters(recurse=False):
-            raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
-
-def _normal_init_(std, mean=0.0):
-    return partial(torch.nn.init.normal_, mean=mean, std=std)
-
-def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
-    del kwargs
-    init_fn_ = _normal_init_(std=std)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
-    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
-    del kwargs
-    if init_std is None:
-        raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
-    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
-    del kwargs
-    std = math.sqrt(2 / (5 * d_model))
-    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
-    """From section 2.3.1 of GPT-NeoX-20B:
-
-    An Open-Source Autoregressive Language Model, Black et al. (2022)
-    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
-    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
-    """
-    del kwargs
-    residual_div = n_layers / math.sqrt(10)
-    if verbose > 1:
-        warnings.warn(f'setting init_div_is_residual to {residual_div}')
-    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
-    del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
-    kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
-    del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
-    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
-    del kwargs
-    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-
-def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
-    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
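
For reference, the file's single export is MODEL_INIT_REGISTRY, which maps the registry names above to per-module initializer functions. Below is a minimal sketch of how such a registry is typically consumed via nn.Module.apply; the import path, model, and hyperparameter values are illustrative assumptions and not part of this commit (the deleted file expects to live in a package that also provides norm.py).

    from functools import partial
    from torch import nn
    from param_init_fns import MODEL_INIT_REGISTRY  # assumed import path; the file is package-relative in the original repo

    # Illustrative sizes; in the MPT modeling code these typically come from the model config.
    d_model, n_layers = 256, 4
    model = nn.Sequential(nn.Linear(d_model, d_model), nn.Linear(d_model, d_model))

    # Look up an init fn by its registry key and bind the init-config kwargs.
    # nn.Module.apply then calls it once per submodule: Linear weights are
    # initialized with kaiming_normal_, biases are zeroed, and any module
    # tagged with `_is_residual` is divided by sqrt(2 * n_layers).
    init_fn = partial(
        MODEL_INIT_REGISTRY['kaiming_normal_'],
        n_layers=n_layers,
        d_model=d_model,
        init_div_is_residual=True,
    )
    model.apply(init_fn)

In the surrounding MPT modeling code, the registry key is usually taken from the model's init_config (e.g. init_config['name']), so this lookup happens inside the model's own param_init_fn rather than in user code.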