from abc import ABC

from megatron.core.transformer.transformer_config import TransformerConfig
from termcolor import colored
from torch.nn.modules.module import _addindent


def prehook_save_input_shape(func):
    """Record the input shapes of a call so later no-argument calls can reuse them."""

    def wrapper(self, *input_shapes, **kw_input_shapes):
        if len(input_shapes) + len(kw_input_shapes) == 0:
            # No shapes passed in: fall back to the shapes recorded on a previous
            # call, or 0 if this estimator has never been called with a shape.
            if "_input_shape" in self.__dict__:
                return func(self, *self._input_shape, **self._kw_input_shapes)
            else:
                return 0
        self._input_shape = input_shapes
        self._kw_input_shapes = kw_input_shapes
        return func(self, *self._input_shape, **self._kw_input_shapes)

    return wrapper


class MetaBase(type):
    """Metaclass that wraps ``num_activation`` with the shape-recording prehook."""

    def __new__(cls, name, bases, attrs):
        if "num_activation" in attrs:
            attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
        return super().__new__(cls, name, bases, attrs)
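

# The metaclass above means a subclass's ``num_activation`` remembers the shapes
# it was last called with, e.g. (``SomeEstimator`` is hypothetical):
#
#     est = SomeEstimator(...)
#     est.num_activation([2, 4096, 1024])   # computes and records the shape
#     est.num_activation()                  # reuses the recorded shape (used by __repr__)
#     SomeEstimator(...).num_activation()   # nothing recorded yet -> returns 0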


class MemEstimator(metaclass=MetaBase):
    def __init__(self, *args, **kwargs):
        self._modules = {}

    def __repr__(self):
        # We treat the extra repr like the sub-modules, one item per line.
        extra_lines = []
        # extra_repr = self.extra_repr()
        # # empty string will be split into list ['']
        # if extra_repr:
        #     extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines
        stat = (
            "\t/* n_params="
            + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
            + "\tn_act="
            + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
            + " */"
        )
        main_str = self._get_name() + stat + " ("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"
            main_str += ")"
            return main_str
        # Leaf estimator without sub-modules: use a compact one-line repr.
        return f"{self.__class__.__name__} n_param={self.num_parameter()}"

    def dump(self):
        """Return a nested dict with the name, parameter count, and activation count."""
        ret = {}
        ret['name'] = self._get_name()
        ret['n_params'] = self.num_parameter()
        ret['n_act'] = self.num_activation()
        modules = {}
        for key, module in self._modules.items():
            modules[key] = module.dump()
        if len(modules) > 0:
            ret['modules'] = modules
        return ret

    def _get_name(self):
        return self.__class__.__name__

    def num_parameter(self):
        """
        Calculate the number of model parameters.
        """
        raise NotImplementedError

    def num_activation(self, input_shape: list[int]):
        """
        Calculate the number of activations for the given input shape.

        Args:
            input_shape: shape of the input tensor
        """
        raise NotImplementedError

    def mock_forward(self, input_shape: list[int]):
        """
        Mock the forward pass.

        Args:
            input_shape: shape of the input tensor
        Returns:
            the output shape
        """
        raise NotImplementedError

    def __setattr__(self, name: str, value) -> None:
        # Register MemEstimator-valued attributes as sub-modules, mirroring nn.Module.
        if isinstance(value, MemEstimator):
            modules = self.__dict__.get("_modules")
            modules[name] = value
        return super().__setattr__(name, value)

    def __delattr__(self, name):
        modules = self.__dict__.get("_modules")
        if name in modules:
            del modules[name]
        return super().__delattr__(name)
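

# --- Illustrative sketch, not part of the estimator interface above ----------
# A minimal leaf estimator showing how a subclass is expected to implement
# num_parameter / num_activation / mock_forward.  ``LinearEstimator`` and its
# counting rules are assumptions for illustration only; a real estimator would
# also divide parameter counts by the tensor-parallel world size where needed.
class LinearEstimator(MemEstimator):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

    def num_parameter(self):
        n = self.in_features * self.out_features
        if self.bias:
            n += self.out_features
        return n

    def num_activation(self, input_shape: list[int]):
        # Output activations: leading dims unchanged, last dim becomes out_features.
        # (cum_mul is defined later in this module; it is only needed at call time.)
        return cum_mul(input_shape[:-1]) * self.out_features

    def mock_forward(self, input_shape: list[int]):
        return list(input_shape[:-1]) + [self.out_features]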


# Parallel-state helpers: world sizes are read from the global TransformerConfig;
# ranks other than the pipeline rank are fixed at 0.
_global_config: TransformerConfig = None


def set_global_config(cfg):
    global _global_config
    _global_config = cfg


def get_tensor_model_parallel_world_size():
    global _global_config
    return _global_config.tensor_model_parallel_size


def get_tensor_model_parallel_rank():
    return 0


def get_expert_tensor_parallel_world_size():
    global _global_config
    return _global_config.expert_tensor_parallel_size


def get_expert_tensor_parallel_rank():
    return 0


_pp_rank = 0


def set_pipeline_model_parallel_rank(rank):
    global _pp_rank
    _pp_rank = rank


def get_pipeline_model_parallel_rank():
    global _pp_rank
    return _pp_rank


def get_virtual_pipeline_model_parallel_rank():
    return 0


def get_pipeline_model_parallel_world_size():
    global _global_config
    return _global_config.pipeline_model_parallel_size


def get_expert_model_parallel_rank():
    return 0


def get_expert_model_parallel_world_size():
    global _global_config
    return _global_config.expert_model_parallel_size


def get_virtual_pipeline_model_parallel_world_size():
    global _global_config
    return _global_config.virtual_pipeline_model_parallel_size


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1
    )
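
# Example (hypothetical values): with pipeline_model_parallel_size == 4,
#     set_pipeline_model_parallel_rank(0); is_pipeline_first_stage()  -> True
#     set_pipeline_model_parallel_rank(3); is_pipeline_last_stage()   -> True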


def cum_mul(l: list):
    """Return the product of the entries of ``l``, or 0 if it cannot be computed."""
    try:
        ret = 1
        for one in l:
            ret *= one
        return ret
    except Exception:
        return 0
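

# --- Usage sketch (assumptions: the installed megatron-core TransformerConfig
# accepts these constructor arguments; ``LinearEstimator`` is the illustrative
# class defined above) ---------------------------------------------------------
if __name__ == "__main__":
    cfg = TransformerConfig(
        num_layers=8,
        hidden_size=1024,
        num_attention_heads=8,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
    )
    set_global_config(cfg)
    print("TP world size:", get_tensor_model_parallel_world_size())
    print("PP world size:", get_pipeline_model_parallel_world_size())

    est = LinearEstimator(1024, 4096)
    print("params:", est.num_parameter())
    print("activations:", est.num_activation([2, 4096, 1024]))
    print(est)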