Spaces:
Running
Running
import traceback as tb | |
from typing import Any, Dict, Tuple | |
# Type alias: an exception paired with the stack frames extracted from its traceback.
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]

# Public API of this module.
__all__ = [
    "CheckpointException",
]
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """Pair ``exc`` with the stack frames extracted from its ``__traceback__``.

    Returns:
        A 2-tuple ``(exc, frames)`` where ``frames`` is the
        ``traceback.StackSummary`` built by ``traceback.extract_tb``.
    """
    captured_frames = tb.extract_tb(exc.__traceback__)
    return exc, captured_frames
def _is_wrapped_exception(obj: Any) -> bool:
    """Report whether ``obj`` has the shape of a wrapped exception.

    A wrapped exception is a 2-element tuple whose first item is a
    ``BaseException`` and whose second item is a ``traceback.StackSummary``
    (the same shape ``_wrap_exception`` returns).
    """
    has_pair_shape = isinstance(obj, tuple) and len(obj) == 2
    if not has_pair_shape:
        return False
    first, second = obj[0], obj[1]
    return isinstance(first, BaseException) and isinstance(second, tb.StackSummary)
class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save.

    Holds a mapping from rank to a wrapped exception, i.e. an
    ``(exception, traceback.StackSummary)`` pair (see ``WRAPPED_EXCEPTION``).
    """

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
        # Forward both values to BaseException so they appear in ``self.args``;
        # BaseException's default pickling rebuilds the instance as cls(*args),
        # which matches this signature.
        super().__init__(msg, failures)
        self._failures = failures

    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
        return self._failures

    def __str__(self) -> str:
        """Render every failing rank's traceback and exception message.

        The output starts with a header line listing the failing ranks. For
        each rank it then emits a "Traceback ... (RANK n)" line, the formatted
        stack frames (when present), and the formatted exception line(s).
        """
        # Collect the pieces and join once. This avoids shadowing the builtin
        # ``str`` and avoids quadratic string concatenation.
        # ``keys()`` renders as ``dict_keys([...])``; kept for output compatibility.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)