"""Building blocks for estimating model memory usage without allocating tensors:
a ``MemEstimator`` base class with ``torch.nn.Module``-style composition and repr,
plus single-process mocks of Megatron's parallel-state helpers."""

from abc import ABC
from megatron.core.transformer.transformer_config import TransformerConfig
from torch.nn.modules.module import _addindent
from termcolor import colored


def prehook_save_input_shape(func):
    """Cache the input shape passed to ``num_activation`` so that later
    no-argument calls (e.g. from ``__repr__`` or ``dump``) can reuse it."""

    def wrapper(self, *input_shapes, **kw_input_shapes):
        if len(input_shapes) + len(kw_input_shapes) == 0:
            if "_input_shape" in self.__dict__:
                return func(self, *self._input_shape, **self._kw_input_shapes)
            else:
                # No shape cached yet: report zero activations.
                return 0
        self._input_shape = input_shapes
        self._kw_input_shapes = kw_input_shapes
        return func(self, *self._input_shape, **self._kw_input_shapes)

    return wrapper


class MetaBase(type):
    """Metaclass that wraps every ``num_activation`` override with the
    input-shape caching hook above."""

    def __new__(cls, name, bases, attrs):
        if "num_activation" in attrs:
            attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
        return super().__new__(cls, name, bases, attrs)


class MemEstimator(metaclass=MetaBase):
    """Base class for per-module memory estimators. Sub-estimators assigned as
    attributes are tracked in ``self._modules``, mirroring ``torch.nn.Module``."""

    def __init__(self, *args, **kwargs):
        self._modules = {}

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        # extra_repr = self.extra_repr()
        # # empty string will be split into list ['']
        # if extra_repr:
        #     extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        stat = (
            "\t/* n_params="
            + colored(f"{self.num_parameter() / 1024 / 1024:.2f}M", "red")
            + "\tn_act="
            + colored(f"{self.num_activation() / 1024 / 1024:.2f}M", "green")
            + " */"
        )
        main_str = self._get_name() + stat + " ("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def dump(self):
        """Return the estimator tree as a nested dict of parameter/activation counts."""
        ret = {}
        ret["name"] = self._get_name()
        ret["n_params"] = self.num_parameter()
        ret["n_act"] = self.num_activation()
        modules = {}
        for key, module in self._modules.items():
            modules[key] = module.dump()
        if len(modules) > 0:
            ret["modules"] = modules
        return ret

    def _get_name(self):
        return self.__class__.__name__

    def num_parameter(self):
        """Calculate the number of model parameters."""
        raise NotImplementedError

    def num_activation(self, input_shape: list[int]):
        """Calculate the number of activations for the given input shape.

        Args:
            input_shape: shape of the input tensor.
        """
        raise NotImplementedError

    def mock_forward(self, input_shape: list[int]):
        """Mock the forward pass.

        Args:
            input_shape: shape of the input tensor.
        Returns:
            shape of the output tensor.
        """
        raise NotImplementedError

    def __setattr__(self, name: str, value) -> None:
        # Register sub-estimators so they show up in __repr__ and dump().
        if isinstance(value, MemEstimator):
            modules = self.__dict__.get("_modules")
            modules[name] = value
        return super().__setattr__(name, value)

    def __delattr__(self, name):
        modules = self.__dict__.get("_modules")
        if name in modules:
            del modules[name]
        return super().__delattr__(name)


# ---------------------------------------------------------------------------
# Single-process mocks of megatron.core.parallel_state: parallel sizes are
# read from a globally registered TransformerConfig, and all ranks are 0
# except the pipeline rank, which can be set explicitly.
# ---------------------------------------------------------------------------

_global_config: TransformerConfig = None


def set_global_config(cfg):
    global _global_config
    _global_config = cfg


def get_tensor_model_parallel_world_size():
    return _global_config.tensor_model_parallel_size


def get_tensor_model_parallel_rank():
    return 0


def get_expert_tensor_parallel_world_size():
    return _global_config.expert_tensor_parallel_size


def get_expert_tensor_parallel_rank():
    return 0


_pp_rank = 0


def set_pipeline_model_parallel_rank(rank):
    global _pp_rank
    _pp_rank = rank


def get_pipeline_model_parallel_rank():
    return _pp_rank


def get_virtual_pipeline_model_parallel_rank():
    return 0


def get_pipeline_model_parallel_world_size():
    return _global_config.pipeline_model_parallel_size


def get_expert_model_parallel_rank():
    return 0


def get_expert_model_parallel_world_size():
    return _global_config.expert_model_parallel_size


def get_virtual_pipeline_model_parallel_world_size():
    return _global_config.virtual_pipeline_model_parallel_size


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1
    )


def cum_mul(l: list):
    """Product of all entries of ``l``; returns 0 if ``l`` is not iterable."""
    try:
        ret = 1
        for one in l:
            ret *= one
        return ret
    except Exception:
        return 0
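

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a hypothetical
# LinearEstimator/MLPEstimator pair illustrating how MemEstimator subclasses
# are meant to compose, plus a stand-in config object for the parallel-state
# mocks. Names, shapes, and sizes below are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # A SimpleNamespace stands in for a real TransformerConfig here; only the
    # attributes read by the mocks above need to exist.
    set_global_config(
        SimpleNamespace(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    )
    print(get_tensor_model_parallel_world_size())  # -> 2

    class LinearEstimator(MemEstimator):
        """Hypothetical estimator for a dense layer (illustration only)."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features

        def num_parameter(self):
            return self.in_features * self.out_features + self.out_features

        def num_activation(self, input_shape: list[int]):
            # One activation per element of the output tensor.
            return cum_mul(input_shape[:-1]) * self.out_features

        def mock_forward(self, input_shape: list[int]):
            return list(input_shape[:-1]) + [self.out_features]

    class MLPEstimator(MemEstimator):
        """Hypothetical two-layer MLP estimator (illustration only)."""

        def __init__(self, hidden_size: int, ffn_size: int):
            super().__init__()
            self.fc1 = LinearEstimator(hidden_size, ffn_size)
            self.fc2 = LinearEstimator(ffn_size, hidden_size)

        def num_parameter(self):
            return self.fc1.num_parameter() + self.fc2.num_parameter()

        def num_activation(self, input_shape: list[int]):
            hidden_shape = self.fc1.mock_forward(input_shape)
            return self.fc1.num_activation(input_shape) + self.fc2.num_activation(
                hidden_shape
            )

        def mock_forward(self, input_shape: list[int]):
            return self.fc2.mock_forward(self.fc1.mock_forward(input_shape))

    mlp = MLPEstimator(hidden_size=4096, ffn_size=16384)
    # The first call caches the input shape via prehook_save_input_shape ...
    print(mlp.num_activation([1024, 1, 4096]))
    # ... so __repr__ can report parameter/activation counts without being
    # handed a shape again.
    print(mlp)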