#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools | |
import logging | |
import time | |
from typing import Any, Callable, Dict, List, Tuple, TypeVar | |
from typing_extensions import ParamSpec | |
import torch | |
import torch.distributed as dist | |
from torch.distributed.logging_handlers import _log_handlers | |
# Empty export list: `from <this module> import *` brings in no names.
__all__: List[str] = []
def _get_or_create_logger() -> logging.Logger:
    """Build the logger used for c10d debug records.

    The logger is named ``c10d-<handler class name>``, where the handler is the
    one returned by ``_get_logging_handler()`` for the default destination.

    Returns:
        A ``logging.Logger`` set to ``DEBUG`` level, with propagation disabled
        and the default handler attached.
    """
    logging_handler, log_handler_name = _get_logging_handler()
    logger = logging.getLogger(f"c10d-{log_handler_name}")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    # The formatter is installed on the handler object itself, so it applies
    # wherever that same handler instance is used.
    logging_handler.setFormatter(formatter)
    # Do not forward these records to ancestor loggers (e.g. the root logger).
    logger.propagate = False
    logger.addHandler(logging_handler)
    return logger
def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
    """Look up the log handler registered for ``destination``.

    Args:
        destination: Key into the ``_log_handlers`` registry.

    Returns:
        A ``(handler, handler_class_name)`` tuple, where the second item is
        ``type(handler).__name__``.

    Raises:
        Whatever ``_log_handlers[destination]`` raises for an unknown key
        (``KeyError`` if the registry is a plain dict).
    """
    log_handler = _log_handlers[destination]
    log_handler_name = type(log_handler).__name__
    return (log_handler, log_handler_name)
# Module-level logger, created once at import time. (A `global` statement is a
# no-op at module scope, so a plain assignment is all that is needed.)
_c10d_logger = _get_or_create_logger()
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Assemble the dictionary that is logged for one call.

    Args:
        func_name: Name of the function being reported.
        *args: Positional arguments of the reported call.
        **kwargs: Keyword arguments of the reported call. ``group`` (and
            ``pg``) are also read to describe the process group.

    Returns:
        A dict of strings. It always contains ``func_name`` and ``args``.
        When ``torch.distributed`` is initialized it additionally contains
        ``pg_name``, ``backend``, ``world_size``, ``group_size``,
        ``global_rank`` and ``local_rank``, plus ``nccl_version`` when the
        backend is ``"nccl"``.
    """
    msg_dict: Dict[str, Any] = {
        "func_name": f"{func_name}",
        "args": f"{args}, {kwargs}",
    }
    if not dist.is_initialized():
        # No process group exists yet, so there is nothing more to report.
        return msg_dict

    group = kwargs.get("group")
    msg_dict.update(
        {
            # NOTE(review): this reads the 'pg' keyword while every other entry
            # reads 'group' -- confirm against callers which one is intended.
            "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(group)}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(group)}",
            "global_rank": f"{dist.get_rank()}",
            # Despite the key name, this is dist.get_rank() evaluated for `group`.
            "local_rank": f"{dist.get_rank(group)}",
        }
    )
    if msg_dict["backend"] == "nccl":
        nccl_version = torch.cuda.nccl.version()
        msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
    return msg_dict
_T = TypeVar('_T') | |
_P = ParamSpec('_P') | |
def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator that logs any ``Exception`` raised by ``func`` and re-raises it.

    On failure, the message dict built by ``_get_msg_dict`` is extended with an
    ``"error"`` entry (the stringified exception) and emitted at ``DEBUG``
    level on ``_c10d_logger``. The original exception is then re-raised
    unchanged. Exceptions that do not derive from ``Exception`` are not logged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """

    # Preserve func's __name__/__doc__ so the wrapper is transparent to
    # introspection and to any further decorators that read func.__name__.
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        try:
            return func(*args, **kwargs)
        except Exception as error:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            raise

    return wrapper
def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator that logs how long each successful call to ``func`` took.

    After ``func`` returns, the message dict built by ``_get_msg_dict`` is
    extended with a ``"time_spent"`` entry of the form ``"<n>ns"`` and emitted
    at ``DEBUG`` level on ``_c10d_logger``. If ``func`` raises, nothing is
    logged and the exception propagates.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """

    # Preserve func's __name__/__doc__ so the wrapper is transparent to
    # introspection and to any further decorators that read func.__name__.
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # perf_counter_ns is monotonic, so the measured interval cannot be
        # skewed (or go negative) when the system clock is adjusted mid-call.
        t1 = time.perf_counter_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.perf_counter_ns() - t1

        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper