from megatron.core.transformer.transformer_config import TransformerConfig
from torch.nn.modules.module import _addindent
from termcolor import colored

def prehook_save_input_shape(func):
    """Cache the input shape passed to ``num_activation`` so later calls can omit it."""

    def wrapper(self, *input_shapes, **kw_input_shapes):
        if len(input_shapes) + len(kw_input_shapes) == 0:
            # Called without arguments: reuse the previously cached shape if there
            # is one, otherwise report zero activations.
            if "_input_shape" in self.__dict__:
                return func(self, *self._input_shape, **self._kw_input_shapes)
            else:
                return 0
        # Remember the shapes so later argument-less calls can reuse them.
        self._input_shape = input_shapes
        self._kw_input_shapes = kw_input_shapes
        return func(self, *self._input_shape, **self._kw_input_shapes)

    return wrapper
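
# Illustration (not executed; ``est`` and the shape values are hypothetical): once a
# subclass's ``num_activation`` is wrapped by MetaBase below, it can be called once
# with a shape and afterwards without arguments:
#     est.num_activation([seq_len, batch, hidden])  # computes and caches the shape
#     est.num_activation()                          # reuses the cached shape
# An estimator that was never given a shape returns 0 from the no-argument call.
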
class MetaBase(type):
    """Metaclass that wraps ``num_activation`` with the input-shape caching prehook."""

    def __new__(cls, name, bases, attrs):
        if "num_activation" in attrs:
            attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
        return super().__new__(cls, name, bases, attrs)

class MemEstimator(metaclass=MetaBase):
    """Base class for modules that estimate parameter and activation counts."""

    def __init__(self, *args, **kwargs):
        self._modules = {}

    def __repr__(self):
        # Follows the structure of torch.nn.Module.__repr__: child estimators are
        # rendered indented and annotated with parameter and activation counts.
        extra_lines = []
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        stat = (
            "\t/* n_params="
            + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
            + "\tn_act="
            + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
            + " */"
        )
        main_str = self._get_name() + stat + " ("
        if lines:
            # A single extra line with no children stays on one line.
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n " + "\n ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def dump(self):
        """Return a nested dict of names, parameter counts and activation counts."""
        ret = {}
        ret['name'] = self._get_name()
        ret['n_params'] = self.num_parameter()
        ret['n_act'] = self.num_activation()
        modules = {}
        for key, module in self._modules.items():
            modules[key] = module.dump()
        if len(modules) > 0:
            ret['modules'] = modules
        return ret

    def _get_name(self):
        return self.__class__.__name__

    def num_parameter(self):
        """Calculate the number of model parameters."""
        raise NotImplementedError

    def num_activation(self, input_shape: list[int]):
        """Calculate the number of activations for the given input shape.

        Args:
            input_shape: shape of the input tensor
        """
        raise NotImplementedError

    def mock_forward(self, input_shape: list[int]):
        """Mock the forward pass.

        Args:
            input_shape: shape of the input tensor
        Returns:
            shape of the output tensor
        """
        raise NotImplementedError

    def __setattr__(self, name: str, value) -> None:
        # Track child estimators in self._modules, mirroring torch.nn.Module.
        if isinstance(value, MemEstimator):
            modules = self.__dict__.get("_modules")
            modules[name] = value
        return super().__setattr__(name, value)

    def __delattr__(self, name):
        modules = self.__dict__.get("_modules")
        if name in modules:
            del modules[name]
        return super().__delattr__(name)
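
# Subclasses are expected to override num_parameter(), num_activation() and
# mock_forward(); estimators assigned as attributes are collected in _modules and
# rendered recursively by __repr__() and dump().  A hypothetical end-to-end example
# is sketched under the __main__ guard at the bottom of this file.
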
# Minimal stand-ins for megatron.core.parallel_state: parallel sizes are answered
# from the stored TransformerConfig instead of an initialised process group.
_global_config: TransformerConfig = None


def set_global_config(cfg):
    global _global_config
    _global_config = cfg
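
# Usage sketch (illustrative, not executed): drive the stand-ins from the same
# TransformerConfig that describes the model.
#     cfg = TransformerConfig(...)              # a fully initialised megatron config
#     cfg.tensor_model_parallel_size = 2
#     set_global_config(cfg)
#     get_tensor_model_parallel_world_size()    # -> 2
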
def get_tensor_model_parallel_world_size():
    return _global_config.tensor_model_parallel_size


def get_tensor_model_parallel_rank():
    return 0


def get_expert_tensor_parallel_world_size():
    return _global_config.expert_tensor_parallel_size


def get_expert_tensor_parallel_rank():
    return 0

_pp_rank = 0


def set_pipeline_model_parallel_rank(rank):
    global _pp_rank
    _pp_rank = rank


def get_pipeline_model_parallel_rank():
    return _pp_rank


def get_virtual_pipeline_model_parallel_rank():
    return 0


def get_pipeline_model_parallel_world_size():
    return _global_config.pipeline_model_parallel_size


def get_expert_model_parallel_rank():
    return 0


def get_expert_model_parallel_world_size():
    return _global_config.expert_model_parallel_size


def get_virtual_pipeline_model_parallel_world_size():
    return _global_config.virtual_pipeline_model_parallel_size

def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0

def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    # Unlike is_pipeline_first_stage, virtual pipeline stages are not checked here;
    # ``ignore_virtual`` is accepted only for API compatibility.
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1
    )

def cum_mul(l: list):
    """Return the product of all entries of ``l``, or 0 if it cannot be computed."""
    try:
        ret = 1
        for one in l:
            ret *= one
        return ret
    except Exception:
        return 0
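

if __name__ == "__main__":
    # Self-contained usage sketch.  DemoLinear and DemoMLP are hypothetical
    # estimators written only to illustrate the subclass contract assumed above
    # (num_parameter / num_activation / mock_forward plus child tracking); they are
    # not part of the real estimator set.
    class DemoLinear(MemEstimator):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features

        def num_parameter(self):
            # weight matrix plus bias vector
            return self.in_features * self.out_features + self.out_features

        def num_activation(self, input_shape):
            # output keeps the leading dims, last dim becomes out_features
            return cum_mul(input_shape[:-1]) * self.out_features

        def mock_forward(self, input_shape):
            return input_shape[:-1] + [self.out_features]

    class DemoMLP(MemEstimator):
        def __init__(self, hidden):
            super().__init__()
            self.fc1 = DemoLinear(hidden, 4 * hidden)
            self.fc2 = DemoLinear(4 * hidden, hidden)

        def num_parameter(self):
            return self.fc1.num_parameter() + self.fc2.num_parameter()

        def num_activation(self, input_shape):
            hidden_shape = self.fc1.mock_forward(input_shape)
            return self.fc1.num_activation(input_shape) + self.fc2.num_activation(hidden_shape)

        def mock_forward(self, input_shape):
            return self.fc2.mock_forward(self.fc1.mock_forward(input_shape))

    mlp = DemoMLP(1024)
    mlp.num_activation([8, 512, 1024])  # caches the input shape on every child
    print(mlp)                          # __repr__ reuses the cached shapes
    print(mlp.dump())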