import collections
import dis
import functools
import itertools
import logging
import os
import random
import sys
import threading
import time
import traceback
import types
import typing
import weakref
from typing import Any, Callable, Dict, List, Optional, Set

from torch.fx._lazy_graph_module import (  # type: ignore[attr-defined]
    _use_lazy_graph_module,
)

try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

import torch
import torch._logging
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._traceback import format_traceback_short

from . import config, exc, trace_rules
from .backends.registry import CompilerFn
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
    check_inst_exn_tab_entries_valid,
    Instruction,
    is_generator,
    propagate_inst_exn_table_entries,
    transform_code_object,
)
from .cache_size import (
    CacheSizeRelevantForFrame,
    compute_cache_size,
    exceeds_cache_size_limit,
    is_recompilation,
)
from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    format_error_msg,
    InternalTorchDynamoError,
    TorchRuntimeError,
    UncapturedHigherOrderOpError,
    unimplemented,
    Unsupported,
)
from .guards import (
    CheckFunctionManager,
    get_and_maybe_log_recompilation_reason,
    GuardedCode,
)
from .hooks import Hooks
from .output_graph import OutputGraph
from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator, SpeculationLog
from .trace_rules import is_numpy
from .types import BytecodeHook
from .utils import (
    CleanupManager,
    CompilationMetrics,
    counters,
    dynamo_timed,
    format_bytecode,
    frame_phase_timing,
    gen_record_file_name,
    increment_frame,
    is_namedtuple,
    istype,
    LazyString,
    maybe_cprofile,
    orig_code_map,
    record_compilation_metrics,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
)

log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")

GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

compile_lock = threading.RLock()

class Tracker:
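    """Track objects by id while holding only weak references to them; an id is
    removed automatically once the referenced object is garbage collected."""
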
    def __init__(self):
        self.seen = []
        self.seen_ids = set()

    def add(self, strong_obj):
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item):
        return id(item) in self.seen_ids

    def clear(self):
        self.seen.clear()
        self.seen_ids.clear()

input_codes = Tracker()
output_codes = Tracker()

initial_global_state: Optional[GlobalStateGuard] = None


def fx_forward_from_src_skip_result(*args, **kwargs):
    # we monkey patch FX to prevent an infinite loop of trying to convert
    # our own generated code
    result: types.FunctionType = original_forward_from_src(*args, **kwargs)
    skip_code(result.__code__)
    return result

def preserve_global_state(fn):
    """
    Decorator that wraps ``fn`` to:
        1) Save/restore torch.is_grad_enabled() state
        2) Save/restore python random state
        3) Save/restore torch random state
        4) Monkey patch torch.fx.graph_module._forward_from_src
    """

    def _fn(*args, **kwargs):
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        prior_inference_mode = torch.is_inference_mode_enabled()
        prior_deterministic = torch.are_deterministic_algorithms_enabled()
        prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        py_rng_state = random.getstate()
        torch_rng_state = torch.random.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        cleanup = setup_compile_debug()
        try:
            return fn(*args, **kwargs)
        finally:
            cleanup.close()
            torch._C._set_grad_enabled(prior_grad_mode)
            torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
            torch.use_deterministic_algorithms(
                prior_deterministic, warn_only=prior_warn_only
            )
            random.setstate(py_rng_state)
            torch.random.set_rng_state(torch_rng_state)
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
            torch.fx.graph_module._forward_from_src = prior_fwd_from_src
            assert (
                guards.check()
            ), "Global state changed while dynamo tracing, please report a bug"

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn

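
# Illustrative usage of preserve_global_state (not part of the original module):
# wrap any function that may mutate global torch/Python state during tracing, e.g.
#
#   @preserve_global_state
#   def run_tracing_step(*args, **kwargs):
#       ...
#
# The wrapper restores grad/inference mode, RNG state, deterministic-algorithm
# flags and the FX monkey patch afterwards, and asserts that no other global
# state leaked while dynamo was tracing.
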
def has_tensor_in_frame(frame):
    """Check if the frame has torch.* related bits"""
    # Check if the function was decorated using torch._dynamo.optimize
    if frame.f_code in always_optimize_code_objects:
        return True

    # Check if there is global import of torch.*
    for co_name in frame.f_code.co_names:
        if co_name in frame.f_globals:
            obj = frame.f_globals[co_name]
            if isinstance(obj, types.ModuleType) and (
                obj.__name__.startswith("torch.") or obj is torch
            ):
                return True
            # ... or a global import of numpy.*
            if np and config.trace_numpy and (obj is np or is_numpy(obj)):
                return True

    seen_ids: Dict[int, bool] = dict()

    def has_tensor(obj):
        """Recursively check if the obj has a tensor"""
        obj_id = id(obj)
        if obj_id in seen_ids:
            return seen_ids[obj_id]
        seen_ids[obj_id] = False

        if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
            istype(obj, type) and issubclass(obj, torch.nn.Module)
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif (
            config.trace_numpy
            and np
            and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif istype(obj, (list, tuple)):
            seen_ids[obj_id] = any(has_tensor(v) for v in obj)
            return seen_ids[obj_id]
        elif istype(obj, dict):
            # Some packages like pytest can be updated during runtime. So, make a
            # copy of values to avoid issues like "RuntimeError: dictionary
            # changed size during iteration"
            values = list(obj.values())
            seen_ids[obj_id] = any(has_tensor(v) for v in values)
            return seen_ids[obj_id]
        elif istype(obj, (str, int, float, type(None), bool)):
            seen_ids[obj_id] = False
            return seen_ids[obj_id]
        elif is_namedtuple(obj) and hasattr(obj, "_fields"):
            seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
            return seen_ids[obj_id]
        else:
            # if config.debug:
            #     print(
            #         f"Assuming that object of type {type(obj)} does not have a tensor"
            #     )
            return False

    # Check if the passed arguments are of type Tensor
    for value in frame.f_locals.values():
        if has_tensor(value):
            return True

    log.debug(
        "skipping because no torch.* %s \
            %s %s",
        frame.f_code.co_name,
        frame.f_code.co_filename,
        frame.f_code.co_firstlineno,
    )

    return False

def exception_handler(e, code, frame=None, export=False):
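    """If the exception carries an execution record, write it to a replay file
    and attach the filename, then augment the exception message with
    Dynamo-specific context."""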
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)
        e.record_filename = record_filename

    augment_exc_message(e, export=export)


FRAME_COUNTER = 0
FRAME_COMPILE_COUNTER: typing.Counter[int] = collections.Counter()

def convert_frame_assert(
    compiler_fn: CompilerFn,
    one_graph: bool = True,
    export: bool = False,
    export_constraints=None,
):
    """Fully convert a frame into an FX graph"""
    reset_graph_break_dup_checker()

    def _convert_frame_assert(
        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
    ):
        increment_frame()

        code = frame.f_code

        cache_size = compute_cache_size(frame, cache_entry)
        recompile_reasons = None
        if is_recompilation(cache_size):
            recompile_reasons = get_and_maybe_log_recompilation_reason(
                cache_entry, frame
            )

        input_codes.add(code)
        if code in output_codes:
            return None
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return None
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            (
                "transformers/file_utils.py",
                "transformers/utils/generic.py",
                "diffusers/utils/outputs.py",
            )
        ):
            # not needed, but cleans up torchbench error stats
            return None
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally, and it is also not
            # likely to be useful to compile - skip the whole frame
            return None
        if code.co_name == "__init__" and code.co_filename.startswith(
            os.path.dirname(torch.optim.__file__)
        ):
            # optimizer support is still incomplete; see
            # test_state_dict in test/dynamo/test_optimizers.py
            return None

        # Check if the frame is generated by an exec builtin call
        # TODO - Running an exec-generated frame seems to propagate f_globals
        # to the next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return None

        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issues with
            # the len keyword in the LIST_LEN guard.
            return None

        if is_generator(code):
            unimplemented("generator")

        exceeded, limit_type = exceeds_cache_size_limit(cache_size)
        if exceeded:

            def format_func_info(code):
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            def format_guard_failures():
                assert recompile_reasons, "TODO(whc) any other recompile reasons?"
                return recompile_reasons[-1]

            log.warning(
                "torch._dynamo hit config.%s (%s)\n"
                "   function: %s\n"
                "   last reason: %s\n"
                'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                "To diagnose recompilation issues, see %s.",
                limit_type,
                getattr(config, limit_type),
                format_func_info(code),
                format_guard_failures(),
                troubleshooting_url,
            )
            unimplemented(f"{limit_type} reached")

        if not has_tensor_in_frame(frame):
            return None

        global initial_global_state
        initial_global_state = GlobalStateGuard()

        global FRAME_COUNTER
        if "_id" not in frame_state:
            frame_state["_id"] = FRAME_COUNTER
            FRAME_COUNTER += 1
        frame_id = frame_state["_id"]

        frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
        FRAME_COMPILE_COUNTER[frame_id] += 1

        compile_id = CompileId(frame_id, frame_compile_id)

        signpost_event(
            "dynamo",
            "_convert_frame_assert._compile",
            {
                "co_name": code.co_name,
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
            },
        )

        return _compile(
            frame.f_code,
            frame.f_globals,
            frame.f_locals,
            frame.f_builtins,
            compiler_fn,
            one_graph,
            export,
            export_constraints,
            hooks,
            cache_size,
            frame,
            frame_state=frame_state,
            compile_id=compile_id,
            skip=skip + 1,
        )

    _convert_frame_assert._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]

    def _clone_with_backend(backend):
        return convert_frame_assert(backend, one_graph, export, export_constraints)

    _convert_frame_assert._clone_with_backend = _clone_with_backend  # type: ignore[attr-defined]
    return _convert_frame_assert

from collections import OrderedDict
from torch.utils.hooks import RemovableHandle

# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()


def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
    """Register hooks for bytecode generated by Dynamo. The hook can do some
    logging, as well as return a new code object to be used. Please refer
    to `BytecodeHook` for the hook signature.
    """
    handle = RemovableHandle(_bytecode_hooks)
    _bytecode_hooks[handle.id] = hook
    return handle

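
# Illustrative usage (not part of the original module): a bytecode hook receives
# the original and the Dynamo-transformed code objects; returning None keeps the
# transformed bytecode, while returning a code object substitutes it (see the
# call sites in compile_inner below).
#
#   def _log_bytecode_hook(code, out_code):
#       print(f"dynamo transformed {code.co_name}")
#       return None  # keep Dynamo's output unchanged
#
#   handle = register_bytecode_hook(_log_bytecode_hook)
#   ...
#   handle.remove()
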
def _compile(
    code: types.CodeType,
    globals: Dict[str, object],
    locals: Dict[str, object],
    builtins: Dict[str, object],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    export_constraints,
    hooks: Hooks,
    cache_size: CacheSizeRelevantForFrame,
    frame: Optional[types.FrameType] = None,
    frame_state=None,
    compile_id=None,
    *,
    skip: int = 0,
) -> Optional[GuardedCode]:
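    """Compile ``code`` into transformed bytecode plus a guard-checking
    function, returning a GuardedCode, or None if the frame should be left
    unmodified."""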
    from torch.fx.experimental.validator import (
        bisect,
        BisectValidationException,
        translation_validation_enabled,
        ValidationException,
    )

    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None
    # This is shared across restarts
    mutated_closure_cell_contents: Set[str] = set()
    speculation_log = SpeculationLog()
    torch._dynamo.callback_handler.run_start_callbacks()

    def transform(instructions, code_options):
        nonlocal output
        nonlocal tracer
        speculation_log.restart()
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            code_options,
            compiler_fn,
            one_graph,
            export,
            export_constraints,
            mutated_closure_cell_contents,
            frame_state=frame_state,
            speculation_log=speculation_log,
        )

        try:
            with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                tracer.run()
        except exc.UnspecializeRestartAnalysis:
            speculation_log.clear()
            raise
        except (exc.SpeculationRestartAnalysis, exc.SkipFrame):
            raise
        except Exception:
            if translation_validation_enabled():
                bisect(tracer.output.shape_env)
            raise
        finally:
            tracer.output.call_cleanup_hooks()

        output = tracer.output
        assert output is not None
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)

        if config.dead_code_elimination:
            propagate_inst_exn_table_entries(instructions)
            check_inst_exn_tab_entries_valid(instructions)
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

    def compile_inner(
        code: types.CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[List[Instruction], Dict[str, Any]], Any],
    ) -> Optional[GuardedCode]:
        nonlocal output
        for attempt in itertools.count():
            CompileContext.get().attempt = attempt
            try:
                out_code = transform_code_object(code, transform)
                break
            except exc.RestartAnalysis as e:
                log.info(
                    "Restarting analysis due to %s",
                    LazyString(format_traceback_short, e.__traceback__),
                )
                if attempt > 100:
                    unimplemented("100+ RestartAnalysis() calls")
            except exc.SkipFrame as e:
                log.debug(
                    "Skipping frame %s %s \
                    %s %s",
                    e,
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return None

        def log_bytecode(prefix, name, filename, line_no, code):
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,  # type: ignore[possibly-undefined]
        )

        for hook in _bytecode_hooks.values():
            hook_output = hook(code, out_code)
            if hook_output is not None:
                out_code = hook_output

        orig_code_map[out_code] = code
        output_codes.add(out_code)

        assert output is not None

        # Tests for new code objects.
        # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
        # Only test once the code object is created.
        # They are not tested during runtime.

        def count_args(code):
            import inspect

            return (
                code.co_argcount
                + code.co_kwonlyargcount
                + bool(code.co_flags & inspect.CO_VARARGS)
                + bool(code.co_flags & inspect.CO_VARKEYWORDS)
            )

        total_argcount_old = count_args(code)
        total_argcount_new = count_args(out_code)
        msg = "arg mismatch: "
        msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
        msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
        assert (
            code.co_varnames[:total_argcount_old]
            == out_code.co_varnames[:total_argcount_new]
        ), msg

        msg = "free var mismatch: "
        msg += f"old code object has free var {code.co_freevars}, "
        msg += f"new code object has free var {out_code.co_freevars}"
        assert code.co_freevars == out_code.co_freevars, msg

        msg = "cell var mismatch: "
        msg += f"old code object has cell var {code.co_cellvars}, "
        msg += f"new code object has cell var {out_code.co_cellvars}"
        assert code.co_cellvars == out_code.co_cellvars, msg

        # Skipping Dynamo on a frame without any extracted graph.
        # This does not affect eager functionality. But this is necessary
        # for export for cases where Dynamo-reconstructed bytecode can create
        # new function frames, confusing export in thinking that there
        # are extra graphs now.
        if output.export and output.is_empty_graph():
            return None

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        check_fn = CheckFunctionManager(
            output,
            hooks.guard_fail_fn if hooks else None,
        )

        guarded_code = GuardedCode(out_code, check_fn.check_fn)

        if not output.is_empty_graph() and hooks.guard_export_fn is not None:
            # We should not run the guard_export_fn when Dynamo does not
            # generate any graph. This can happen in export when TorchDynamo
            # generated bytecode has some reconstruction logic for mutated
            # variables which can trigger TorchDynamo on the children frames but
            # they are benign and do not generate any new graphs.
            hooks.guard_export_fn(output.guards)

        return guarded_code

    with compile_context(CompileContext(compile_id)):
        log.debug(
            "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            skip + 2,
            # -2: omit current frame, omit contextlib decorator
            "".join(traceback.format_list(traceback.extract_stack()[: -2 - skip])),
        )
        # -4: -2 as above, plus trace_structured frames
        torch._logging.trace_structured(
            "dynamo_start",
            lambda: {
                "stack": structured.from_traceback(
                    traceback.extract_stack()[: -4 - skip]
                )
            },
        )
        start_time = time.time()
        fail_type: Optional[str] = None
        fail_reason: Optional[str] = None
        fail_user_frame_filename: Optional[str] = None
        fail_user_frame_lineno: Optional[int] = None
        try:
            guarded_code = compile_inner(code, one_graph, hooks, transform)
            return guarded_code
        except (
            Unsupported,
            TorchRuntimeError,
            BackendCompilerFailed,
            AssertionError,
            ConstraintViolationError,
            GuardOnDataDependentSymNode,
            ValidationException,
            UncapturedHigherOrderOpError,
            BisectValidationException,
        ) as e:
            fail_type = str(type(e))
            fail_reason = str(e)
            exception_handler(e, code, frame, export=export)
            if e.innermost_user_frame_summary is not None:  # type: ignore[union-attr]
                fail_user_frame_filename = e.innermost_user_frame_summary.filename  # type: ignore[union-attr]
                fail_user_frame_lineno = e.innermost_user_frame_summary.lineno  # type: ignore[union-attr]
            raise
        except Exception as e:
            fail_type = str(type(e))
            fail_reason = str(e)
            exception_handler(e, code, frame, export=export)
            if e.innermost_user_frame_summary is not None:  # type: ignore[attr-defined]
                fail_user_frame_filename = e.innermost_user_frame_summary.filename  # type: ignore[attr-defined]
                fail_user_frame_lineno = e.innermost_user_frame_summary.lineno  # type: ignore[attr-defined]
            raise InternalTorchDynamoError(str(e)).with_traceback(
                e.__traceback__
            ) from None
        finally:
            if tracer:
                tracer.output.local_scope = {}

            from .utils import curr_frame

            frame_key = str(curr_frame)
            if (
                fail_reason is None
                and output is not None
                and frame_key in frame_phase_timing
            ):
                guard_count = len(output.guards)
                shape_env_guard_count = len(output.shape_env.guards)
                graph_op_count = output.count_calls()
                graph_node_count = len(output.graph.nodes)
                graph_input_count = len(output.placeholders)
                entire_frame_compile_time = frame_phase_timing[frame_key].get(
                    "entire_frame_compile", None
                )
                backend_compile_time = frame_phase_timing[frame_key].get(
                    "backend_compile", None
                )
                inductor_compile_time = frame_phase_timing[frame_key].get(
                    "inductor_compile", None
                )
                code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
                non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
                compliant_custom_ops = {
                    op.__qualname__ for op in output.compliant_custom_ops
                }
            else:
                guard_count = None
                shape_env_guard_count = None
                graph_op_count = None
                graph_node_count = None
                graph_input_count = None
                entire_frame_compile_time = None
                backend_compile_time = None
                inductor_compile_time = None
                code_gen_time = None
                non_compliant_ops = set({})
                compliant_custom_ops = set({})
            metrics = CompilationMetrics(
                frame_key,
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                cache_size.num_cache_entries_with_same_id_matched_objs,
                cache_size.num_cache_entries,
                guard_count,
                shape_env_guard_count,
                graph_op_count,
                graph_node_count,
                graph_input_count,
                start_time,
                entire_frame_compile_time,
                backend_compile_time,
                inductor_compile_time,
                code_gen_time,
                fail_type,
                fail_reason,
                fail_user_frame_filename,
                fail_user_frame_lineno,
                non_compliant_ops,
                compliant_custom_ops,
            )
            record_compilation_metrics(metrics)
            torch._dynamo.callback_handler.run_end_callbacks()

def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
    inner_convert = convert_frame_assert(compiler_fn, one_graph=False)

    def _convert_frame(
        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
    ):
        counters["frames"]["total"] += 1
        try:
            result = inner_convert(
                frame, cache_entry, hooks, frame_state, skip=skip + 1
            )
            counters["frames"]["ok"] += 1
            return result
        except Exception as e:
            # These exception types are "soft" failures, in the sense that
            # we know they are due to something we didn't implement all the
            # way, so we scare the user less about them. That being said, if you
            # are trying to understand why a graph break happened, it's still
            # important to have this information, so offer it.
            #
            # NB: NotImplementedError used to be on this list, but actually
            # it is impossible for it to reach here, as it is converted into
            # InternalTorchDynamoError. This behavior seemed reasonable
            # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
            # someone wanted these to also get suppressed. If so, you'll
            # need to make these exceptions not get wrapped.

            # We intentionally don't want to suppress the error here.
            if isinstance(e, UncapturedHigherOrderOpError):
                raise

            soft_fail = isinstance(e, Unsupported)
            if not config.suppress_errors and not soft_fail:
                raise

            # Suppress the error. NB: It's very important to do the
            # suppression logging HERE, where the actual suppression
            # happens. Previously it was somewhere else and so it was
            # possible to accidentally not log at all.
            record_filename = getattr(e, "record_filename", None)
            code = frame.f_code
            error_msg = format_error_msg(e, code, record_filename, frame)

            if soft_fail:
                log.info(error_msg, exc_info=True)
            else:
                log.warning(error_msg, exc_info=True)
        return None

    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
    _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks)  # type: ignore[attr-defined]
    return _convert_frame

# TODO mlazos: add support for same args, or record them
def replay(filename):
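    """Replay a previously recorded Dynamo compilation from an ExecutionRecord
    file, re-running _compile with the eager backend."""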
    from .backends.debugging import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = dict(itertools.chain(record.globals.items(), globals().items()))

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            compiler_fn=eager,
            one_graph=False,
            export=False,
            export_constraints=None,
            hooks=Hooks(),
            cache_size=CacheSizeRelevantForFrame(0, 0),
            frame=None,
            frame_state={},
        )
    finally:
        config.replay_record_enabled = original_replay_val

def first_real_inst_idx(code):
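    """Index of the first "real" instruction in ``code``: 0 before Python 3.11,
    otherwise the index of the RESUME instruction."""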
    if sys.version_info < (3, 11):
        return 0
    for inst in dis.get_instructions(code):
        if inst.opname == "RESUME":
            return inst.offset // 2
    raise RuntimeError("RESUME instruction not found in code")

def catch_errors_wrapper(callback, hooks: Hooks):
    def catch_errors(frame, cache_entry, frame_state):
        assert frame_state is not None

        is_skipfile = trace_rules.check(frame.f_code)
        if (
            # TODO: the first condition is not covered by any test
            frame.f_lasti >= first_real_inst_idx(frame.f_code)
            or is_skipfile
            or config.disable
        ):
            if log.isEnabledFor(logging.DEBUG):
                skip_reason = (
                    "traced frame already"
                    if frame.f_lasti >= first_real_inst_idx(frame.f_code)
                    else "in skipfiles"
                    if trace_rules.check(frame.f_code)
                    else "dynamo tracing is disabled"
                )
                if not is_skipfile or config.verbose:
                    log.debug(
                        "skipping: %s (reason: %s, file: %s)",
                        frame.f_code.co_name,
                        skip_reason,
                        frame.f_code.co_filename,
                    )
            return None

        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None

        if config._get_optimize_ddp_mode() == "ddp_optimizer":
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from torch._dynamo.backends.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=callback._torchdynamo_orig_callable,
                    )
                    assert hasattr(
                        callback, "_clone_with_backend"
                    ), "DDPOptimizer only supports callback fns that know how to clone themselves."
                    hijacked_callback = callback._clone_with_backend(
                        ddp_optimizer.compile_fn,
                    )
                    return hijacked_callback(frame, cache_entry, hooks, frame_state)

        with compile_lock, _disable_current_modes():
            # skip=1: skip this frame
            return callback(frame, cache_entry, hooks, frame_state, skip=1)

    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
    return catch_errors