Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime | |
import logging | |
import os.path as osp | |
from typing import Optional | |
from mmengine.fileio import dump | |
from mmengine.logging import print_log | |
from . import root | |
from .default_scope import DefaultScope | |
from .registry import Registry | |
def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list:
    """Collect per-node statistics for every registry in a registry tree.

    The walk always starts from ``registry.root``, so passing any node of a
    tree yields information about the whole tree. Nodes are visited
    depth-first in pre-order: a parent before its children, and children in
    the order they appear in ``children``. Children that are not
    :class:`Registry` instances are skipped together with everything below
    them.

    Args:
        registry (Registry): Any registry node belonging to the tree.
        verbose (bool): If True, log the module count of each visited
            registry. Defaults to True.

    Returns:
        list: One dict per visited registry. Each dict holds ``num_modules``
        (the number of entries in ``module_dict``) and ``scope`` (the
        registry's scope). It also holds one key per package path, mapping
        to the list of names registered from that package. A package path
        is the registered object's ``__module__`` without its last
        component, with dots replaced by ``/``.
    """
    results = []
    # Explicit stack instead of recursion. Children are pushed in reverse
    # so that the first child is popped (and its subtree fully explored)
    # before its siblings, reproducing pre-order traversal.
    stack = [registry.root]
    while stack:
        node = stack.pop()
        if not isinstance(node, Registry):
            continue
        module_count = len(node.module_dict)
        node_scope = node.scope
        info = dict(num_modules=module_count, scope=node_scope)
        for registered_name, obj in node.module_dict.items():
            package_path = '/'.join(obj.__module__.split('.')[:-1])
            info.setdefault(package_path, []).append(registered_name)
        if verbose:
            print_log(
                f"Find {module_count} modules in {node_scope}'s "
                f"'{node.name}' registry ",
                logger='current')
        results.append(info)
        stack.extend(reversed(list(node.children.values())))
    return results
def count_registered_modules(save_path: Optional[str] = None,
                             verbose: bool = True) -> dict:
    """Collect statistics on every module registered in MMEngine's registries.

    MMEngine's sub-packages are imported first so that their registration
    side effects run. Afterwards, every attribute of the ``root`` module
    whose name does not start with ``'__'`` and which is a ``Registry`` is
    walked with :func:`traverse_registry_tree`.

    Args:
        save_path (str, optional): Directory in which the results are written
            as ``modules_statistic_results.json``. Nothing is written when
            None. Defaults to None.
        verbose (bool): Whether to log per-registry counts and the final
            summary. Defaults to True. The "saved to" message is logged
            whenever a file is written, regardless of this flag.

    Returns:
        dict: ``{'scan_date': <local time formatted '%Y-%m-%d %H:%M:%S'>,
        'registries': {<attribute name in root>: <traverse_registry_tree
        result for that registry>}}``.
    """
    # Imported only for the side effect of registering their modules.
    import mmengine.dataset
    import mmengine.evaluator
    import mmengine.hooks
    import mmengine.model
    import mmengine.optim
    import mmengine.runner
    import mmengine.visualization  # noqa: F401

    registries_info = {}
    for attr in dir(root):
        if attr.startswith('__'):
            continue
        candidate = getattr(root, attr)
        if not isinstance(candidate, Registry):
            continue
        registries_info[attr] = traverse_registry_tree(candidate, verbose)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    scan_data = {'scan_date': timestamp, 'registries': registries_info}
    if verbose:
        print_log(
            f'Finish registry analysis, got: {scan_data}', logger='current')

    if save_path is None:
        return scan_data
    json_path = osp.join(save_path, 'modules_statistic_results.json')
    dump(scan_data, json_path, indent=2)
    print_log(f'Result has been saved to {json_path}', logger='current')
    return scan_data
def init_default_scope(scope: str) -> None:
    """Make ``scope`` the current default scope.

    There are three cases:

    * No ``DefaultScope`` instance exists yet, or none named ``scope`` has
      been created. An instance named ``scope`` (with ``scope_name=scope``)
      is requested via ``DefaultScope.get_instance``.
    * The current instance's ``scope_name`` differs from ``scope``. A warning
      is logged, and a new instance with ``scope_name=scope`` is requested
      under a timestamped name. An instance named ``scope`` already exists,
      so a distinct name is used to avoid a name conflict.
    * The current scope already matches. Nothing happens.

    Args:
        scope (str): The name of the default scope.
    """
    # Fetch the current instance once. The ``is None`` check below narrows
    # its type, so no ``# type: ignore`` is needed when reading
    # ``scope_name``.
    current_scope = DefaultScope.get_current_instance()
    if (current_scope is None
            or not DefaultScope.check_instance_created(scope)):
        DefaultScope.get_instance(scope, scope_name=scope)
        return
    if current_scope.scope_name != scope:
        print_log(
            'The current default scope '
            f'"{current_scope.scope_name}" is not "{scope}", '
            # Trailing space: adjacent literals are concatenated verbatim.
            '`init_default_scope` will force set the current '
            f'default scope to "{scope}".',
            logger='current',
            level=logging.WARNING)
        # avoid name conflict with the already-created instance ``scope``
        new_instance_name = f'{scope}-{datetime.datetime.now()}'
        DefaultScope.get_instance(new_instance_name, scope_name=scope)