# torch._dynamo exception definitions and error-message formatting helpers.
import os | |
import textwrap | |
from enum import auto, Enum | |
from traceback import extract_stack, format_exc, format_list, StackSummary | |
from typing import cast, NoReturn, Optional | |
import torch._guards | |
from . import config | |
from .utils import counters | |
def exportdb_error_message(case_name):
    """Build a pointer to the exportdb documentation entry for `case_name`."""
    # exportdb anchors use dashes where case names use underscores.
    anchor = case_name.replace("_", "-")
    return (
        "For more information about this error, see: "
        "https://pytorch.org/docs/main/generated/exportdb/index.html#" + anchor
    )
# NOTE(review): this import would conventionally live at the top of the file
# with the other stdlib imports.
import logging

# Module-level logger for this module's generic messages.
log = logging.getLogger(__name__)
# Artifact logger: these messages are routed through torch._logging's
# "graph_breaks" artifact channel.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
class TorchDynamoException(RuntimeError):
    """Base class for all TorchDynamo exceptions."""

    pass


class InternalTorchDynamoError(TorchDynamoException):
    # NOTE(review): name suggests a Dynamo-internal failure rather than a
    # user error — confirm against raise sites.
    pass


class RestartAnalysis(TorchDynamoException):
    # Control-flow exception: signals that frame analysis should restart
    # (see the two specializations below).
    pass


class SpeculationRestartAnalysis(RestartAnalysis):
    pass


class UnspecializeRestartAnalysis(RestartAnalysis):
    pass


class SkipFrame(TorchDynamoException):
    # NOTE(review): presumably signals that the current frame should be
    # skipped entirely — confirm at raise sites.
    pass


class TorchRuntimeError(TorchDynamoException):
    pass
class InvalidBackend(TorchDynamoException):
    """Raised when an unknown backend name is requested."""

    def __init__(self, name):
        message = (
            f"Invalid backend: {name!r}, "
            "see `torch._dynamo.list_backends()` for available backends."
        )
        super().__init__(message)
class ResetRequired(TorchDynamoException):
    """Raised when the backend changes between torch.compile calls without a reset."""

    def __init__(self):
        message = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
            `torch.compile()` with a different backend compiler arguments.
            """
        )
        super().__init__(message)
class BackendCompilerFailed(TorchDynamoException):
    """Wraps an exception raised by a compiler backend, recording which backend failed."""

    def __init__(self, backend_fn, inner_exception):
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        super().__init__(
            f"backend={self.backend_name!r} raised:\n"
            f"{type(inner_exception).__name__}: {inner_exception}"
        )
class Unsupported(TorchDynamoException): | |
def __init__(self, msg): | |
super().__init__(msg) | |
self.real_stack = torch._guards.TracingContext.extract_stack() | |
self.msg = msg | |
self.category: Optional[str] = None | |
self.add_to_stats() | |
def remove_from_stats(self): | |
assert self.category is not None | |
counters[self.category][self.msg] -= 1 | |
if counters[self.category][self.msg] <= 0: | |
del counters[self.category][self.msg] | |
def add_to_stats(self, category="unimplemented"): | |
self.category = category | |
counters[category][self.msg] += 1 | |
class RecompileError(TorchDynamoException):
    # NOTE(review): semantics inferred from the name (recompilation-related
    # failures) — confirm against raise sites.
    pass
class ArgsMismatchError(Unsupported):
    """Raised when supplied arguments do not match the traced callable.

    NOTE(review): semantics inferred from the name and its cond() subclass —
    confirm against raise sites.
    """
    # FIX: removed the redundant `__init__(self, msg)` that only forwarded
    # `msg` to `super().__init__(msg)` — inheriting Unsupported.__init__ is
    # behaviorally identical.
class AttributeMutationError(Unsupported):
    """Raised for attribute mutations Dynamo cannot handle.

    NOTE(review): semantics inferred from the name — confirm at raise sites.
    """
    # FIX: removed the redundant `__init__(self, msg)` that only forwarded
    # `msg` to `super().__init__(msg)` — inheriting Unsupported.__init__ is
    # behaviorally identical.
class CondOpArgsMismatchError(ArgsMismatchError): | |
""" | |
Internal error from cond() due to arguments mismatch. | |
""" | |
def __init__(self, msg): | |
super().__init__(msg) | |
class UserErrorType(Enum):
    """Categories for UserError (errors valid in eager but unsupported by Dynamo)."""

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
class UserError(Unsupported): | |
def __init__(self, error_type: UserErrorType, msg, case_name=None): | |
""" | |
Type of errors that would be valid in Eager, but not supported in TorchDynamo. | |
The error message should tell user about next actions. | |
error_type: Type of user error | |
msg: Actionable error message | |
case_name: (Optional) Unique name (snake case) for the usage example in exportdb. | |
""" | |
if case_name is not None: | |
assert isinstance(case_name, str) | |
if msg.endswith("."): | |
msg += " " | |
else: | |
msg += "\n" | |
msg += exportdb_error_message(case_name) | |
super().__init__(msg) | |
self.error_type = error_type | |
self.message = msg | |
class UncapturedHigherOrderOpError(TorchDynamoException):
    # NOTE(review): presumably raised when a higher-order op cannot be
    # captured — confirm at raise sites.
    pass


class IncorrectUsage(Exception):
    # Deliberately a plain Exception, not a TorchDynamoException.
    # NOTE(review): confirm callers rely on it escaping handlers that
    # target TorchDynamoException.
    pass
# These exceptions are ok to fallback to eager/graph_break.
# All four are defined by torch._subclasses.fake_tensor (fake-tensor
# propagation), i.e. cases where an op's output cannot be faked at trace time.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn: | |
# This function calls unimplemented internally and eventually graph breaks | |
# or falls to eager. unimplemented itself does not print any user warnings, | |
# i.e., its very silent. This helper function is intended when an error is | |
# encountered in the torch.compile stack which is worth showing as warning | |
# to the user. For example, if AOT Autograd backend fails with a fake tensor | |
# exception, its ok to fallback to eager but not silently. Here, we can use | |
# this function to log the message and the stack trace. | |
graph_break_msg = format_error_msg_verbose(e, code) | |
graph_breaks_log.debug("%s", graph_break_msg) | |
log.warning(msg) | |
raise unimplemented(msg) from e | |
def unimplemented(msg: str) -> NoReturn:
    """Signal that Dynamo cannot handle the current construct by raising Unsupported."""
    # Debugging aid: set BREAK=<msg> in the environment to assert-stop here.
    break_sentinel = os.environ.get("BREAK", False)
    assert msg != break_sentinel
    raise Unsupported(msg)
def warning(msg: str) -> None:
    """Record one occurrence of warning `msg` in the global counters."""
    counters["warnings"][msg] += 1
    # Debugging aid: set BREAK=<msg> in the environment to assert-stop here.
    break_sentinel = os.environ.get("BREAK", False)
    assert msg != break_sentinel
# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
    """Wrapper whose repr() is its plain str(), so an augmented KeyError
    message prints verbatim instead of being repr-quoted (see link above)."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self) -> str:
        return self.__str__()
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append actionable debugging hints to `exc`'s message, in place.

    Depending on what `exc` carries, this appends: the user-code stack,
    replay-record instructions, a verbosity hint, minifier pointers, and
    (unless `export` is set or config.suppress_errors) instructions for
    suppressing the error.  Mutates `exc.args` and sets
    `exc.innermost_user_frame_summary`; returns None.
    """
    import traceback

    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]
    real_stack = get_real_stack(exc)
    if real_stack is not None and len(real_stack) > 0:
        # Remember the deepest user frame for downstream error reporting.
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}"
    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"
    if not config.verbose and hasattr(exc, "real_stack"):
        msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'
    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        if hasattr(exc.inner_exception, "buck_command"):
            # Buck-based builds get a ready-to-run buck command.
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {exc.inner_exception.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )
    if not config.suppress_errors and not export:
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            "    import torch._dynamo\n"
            "    torch._dynamo.config.suppress_errors = True\n"
        )
    old_msg = "" if len(exc.args) == 0 else str(exc.args[0])
    if isinstance(exc, KeyError):
        # KeyError repr()s its sole argument when printed, which would mangle
        # a multi-line message; KeyErrorMsg prints it verbatim.
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]
def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
    """Return the user-code stack attached to `exc`, or None if it has none.

    When `frame` is given, frames from the current interpreter stack that lie
    above Dynamo's entry point are prepended (via filter_stack).
    """
    real_stack = getattr(exc, "real_stack", None)
    if real_stack is None:
        return None
    # NB: real_stack may be []; we still report, because stack_above_dynamo
    # can be useful for debugging on its own.
    if frame is None:
        stack_above_dynamo = []
    else:
        # NB: on Python 3.11+ `frame` is a PyInterpreterFrame, not a true
        # frame object, so it cannot be fed to traceback directly.  Properly
        # we would materialize the frame (as _PyFrame_GetFrameObject does),
        # but that populates frame_obj, which the default eval frame dislikes.
        # Instead we extract from where we are right now and rely on
        # filter_stack to drop the dynamo frames; for ease of testing this
        # behavior applies on ALL Python versions.
        stack_above_dynamo = filter_stack(extract_stack())
    return cast(StackSummary, stack_above_dynamo + real_stack)
# filter out all frames after entering dynamo
def filter_stack(stack):
    """Return the prefix of `stack` belonging to user code, dropping dynamo frames."""
    user_frames = []
    for summary in stack:
        if "convert_frame" in summary.filename:
            # Everything from here on down is Dynamo internals.
            break
        # Skip Dynamo's eval-frame shims and explicit optimize() call sites.
        if "eval_frame" in summary.filename or "torch._dynamo.optimize(" in summary.line:
            continue
        user_frames.append(summary)
    return user_frames
def format_error_msg_verbose( | |
exc: Exception, code, record_filename=None, frame=None | |
) -> str: | |
msg = ( | |
f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n" | |
) | |
msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" | |
msg += format_exc() | |
real_stack = get_real_stack(exc, frame) | |
if real_stack is not None: | |
msg += ( | |
"\n" | |
+ "=" * 10 | |
+ " The above exception occurred while processing the following code " | |
+ "=" * 10 | |
+ "\n\n" | |
) | |
msg += "".join(format_list(real_stack)) | |
msg += "\n" | |
msg += "=" * 10 | |
return msg | |
def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str: | |
msg = os.linesep * 2 | |
if config.verbose: | |
msg = format_error_msg_verbose(exc, code, record_filename, frame) | |
else: | |
msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\ | |
line {code.co_firstlineno} \ndue to: \n{format_exc()}" | |
return msg | |