from abc import ABC
from megatron.core.transformer.transformer_config import TransformerConfig
from torch.nn.modules.module import _addindent
from termcolor import colored
def prehook_save_input_shape(func):
    """Cache the input shape(s) passed to ``num_activation`` so that later
    calls without arguments (e.g. from ``__repr__`` or ``dump``) can reuse them."""

    def wrapper(self, *input_shapes, **kw_input_shapes):
        if len(input_shapes) + len(kw_input_shapes) == 0:
            # Called without arguments: fall back to the cached shapes, if any.
            if "_input_shape" in self.__dict__:
                return func(self, *self._input_shape, **self._kw_input_shapes)
            else:
                return 0
        # Cache the shapes for subsequent argument-less calls.
        self._input_shape = input_shapes
        self._kw_input_shapes = kw_input_shapes
        return func(self, *self._input_shape, **self._kw_input_shapes)

    return wrapper
class MetaBase(type):
    """Metaclass that wraps ``num_activation`` of every class it creates with
    ``prehook_save_input_shape`` so the input shape is cached on first use."""

    def __new__(cls, name, bases, attrs):
        if "num_activation" in attrs:
            attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
        return super().__new__(cls, name, bases, attrs)
class MemEstimator(metaclass=MetaBase):
    """Base class for lightweight, nn.Module-like memory estimators."""

    def __init__(self, *args, **kwargs):
        self._modules = {}
def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line.
        # nn.Module would fill extra_lines from self.extra_repr(); estimators do
        # not define extra_repr yet, so the list stays empty.
        extra_lines = []
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
stat = (
"\t/* n_params="
+ colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
+ "\tn_act="
+ colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
+ " */"
)
main_str = self._get_name() + stat + " ("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
return f"{self.__class__.__name__} n_param={self.num_parameter()}"
    def dump(self):
        """Return a nested dict with the name, parameter count, and activation
        count of this estimator and all of its sub-modules."""
        ret = {}
        ret['name'] = self._get_name()
        ret['n_params'] = self.num_parameter()
        ret['n_act'] = self.num_activation()
        modules = {}
        for key, module in self._modules.items():
            modules[key] = module.dump()
        if len(modules) > 0:
            ret['modules'] = modules
        return ret
def _get_name(self):
return self.__class__.__name__
    def num_parameter(self):
        """
        Calculate the number of model parameters.
        """
        raise NotImplementedError
    def num_activation(self, input_shape: list[int]):
        """
        Calculate the number of activation elements for the given input shape.
        Args:
            input_shape: shape of the input tensor
        """
        raise NotImplementedError
    def mock_forward(self, input_shape: list[int]):
        """
        Mock the forward pass.
        Args:
            input_shape: shape of the input tensor
        Returns:
            output shape
        """
        raise NotImplementedError
    def __setattr__(self, name: str, value) -> None:
        # Register MemEstimator attributes as sub-modules, mirroring nn.Module.
        if isinstance(value, MemEstimator):
            modules = self.__dict__.get("_modules")
            modules[name] = value
        return super().__setattr__(name, value)
def __delattr__(self, name):
modules = self.__dict__.get("_modules")
if name in modules:
del modules[name]
return super().__delattr__(name)
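# A minimal sketch (not part of this module's API) of how a concrete estimator
# could look. The LinearEstimator class and the numbers below are hypothetical
# and only illustrate the contract: MetaBase wraps num_activation, so the first
# call with a shape caches it, and later argument-less calls (as issued by
# __repr__ and dump) reuse the cached shape.
#
#     class LinearEstimator(MemEstimator):
#         def __init__(self, in_features, out_features):
#             super().__init__()
#             self.in_features = in_features
#             self.out_features = out_features
#
#         def num_parameter(self):
#             return self.in_features * self.out_features + self.out_features
#
#         def num_activation(self, input_shape: list[int]):
#             return cum_mul(input_shape[:-1] + [self.out_features])
#
#         def mock_forward(self, input_shape: list[int]):
#             return input_shape[:-1] + [self.out_features]
#
#     est = LinearEstimator(4096, 4096)
#     est.num_activation([8, 1024, 4096])  # caches the input shape
#     print(est)                           # __repr__ reuses the cached shape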
# Config-driven stand-ins for Megatron's parallel_state queries: world sizes are
# read from the registered TransformerConfig. Most rank queries return 0 because
# estimation is done from rank 0's perspective; the pipeline rank can be set
# explicitly via set_pipeline_model_parallel_rank.
_global_config: TransformerConfig = None


def set_global_config(cfg: TransformerConfig):
    global _global_config
    _global_config = cfg
def get_tensor_model_parallel_world_size():
global _global_config
return _global_config.tensor_model_parallel_size
def get_tensor_model_parallel_rank():
return 0
def get_expert_tensor_parallel_world_size():
global _global_config
return _global_config.expert_tensor_parallel_size
def get_expert_tensor_parallel_rank():
return 0
_pp_rank = 0
def set_pipeline_model_parallel_rank(rank):
global _pp_rank
_pp_rank = rank
def get_pipeline_model_parallel_rank():
global _pp_rank
return _pp_rank
def get_virtual_pipeline_model_parallel_rank():
return 0
def get_pipeline_model_parallel_world_size():
global _global_config
return _global_config.pipeline_model_parallel_size
def get_expert_model_parallel_rank():
return 0
def get_expert_model_parallel_world_size():
global _global_config
return _global_config.expert_model_parallel_size
def get_virtual_pipeline_model_parallel_world_size():
global _global_config
return _global_config.virtual_pipeline_model_parallel_size
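# Usage sketch (hypothetical values; the required TransformerConfig fields and
# their defaults depend on the megatron-core version in use):
#
#     cfg = TransformerConfig(
#         num_layers=32,
#         hidden_size=4096,
#         num_attention_heads=32,
#         tensor_model_parallel_size=2,
#         pipeline_model_parallel_size=4,
#     )
#     set_global_config(cfg)
#     get_tensor_model_parallel_world_size()    # -> 2
#     get_pipeline_model_parallel_world_size()  # -> 4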
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if (
get_virtual_pipeline_model_parallel_world_size() is not None
and get_virtual_pipeline_model_parallel_rank() != 0
):
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline-model-parallel stage, False otherwise.

    Note: unlike is_pipeline_first_stage, the virtual pipeline rank is not
    checked here, so ignore_virtual is currently unused.
    """
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1
    )
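# Example (assuming set_global_config has already been called): the stage
# helpers follow directly from the pipeline rank set beforehand.
#
#     set_pipeline_model_parallel_rank(0)
#     assert is_pipeline_first_stage()
#     set_pipeline_model_parallel_rank(get_pipeline_model_parallel_world_size() - 1)
#     assert is_pipeline_last_stage()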
def cum_mul(l: list):
    """Return the product of the entries of l (e.g. the number of elements of a
    tensor shape); return 0 if the entries cannot be multiplied."""
    try:
        ret = 1
        for one in l:
            ret *= one
        return ret
    except TypeError:
        return 0
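# Example: cum_mul gives the element count of a shape; multiplying by the bytes
# per element (e.g. 2 for bf16) converts it to bytes.
#
#     cum_mul([8, 1024, 4096])      # -> 33554432 elements
#     cum_mul([8, 1024, 4096]) * 2  # -> 67108864 bytes, assuming bf16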