from __future__ import annotations

import contextlib
import dataclasses
import enum
import functools
import logging
import threading
import traceback
import unittest.mock
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
)

import torch
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary

log = logging.getLogger(__name__)
if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import sympy


"""
torch._guards is the definitional source of truth for general purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""
class CompileId(NamedTuple):
    frame_id: int
    # This id is per-frame, and counts how many times we've compiled this
    # frame.  This could have been a global id but having this be per-frame
    # gives you a better intuitive sense for how many recompiles have occurred
    # so far.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return f"{self.frame_id}/{self.frame_compile_id}"


class TraceId(NamedTuple):
    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self):
        if self.attempt == 0:
            return str(self.compile_id)
        else:
            return f"{self.compile_id}_{self.attempt}"
class GuardSource(enum.Enum):
    LOCAL = 0
    GLOBAL = 1
    LOCAL_NN_MODULE = 2
    GLOBAL_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11

    def is_fsdp_module(self) -> bool:
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_nn_module(self) -> bool:
        return (
            self
            in (
                GuardSource.GLOBAL_NN_MODULE,
                GuardSource.LOCAL_NN_MODULE,
            )
            or self.is_fsdp_module()
        )

    def is_local(self):
        return self in (
            GuardSource.LOCAL,
            GuardSource.LOCAL_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
        )
""" | |
Base class for a "GuardBuilder" role. | |
The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little | |
confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference | |
to torchdynamo's GuardBuilder. | |
Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based | |
on GuardSource's select function. | |
There is value in keeping this GuardBuilderBase empty to keep layering clean. | |
""" | |
class GuardBuilderBase: | |
pass | |
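# Illustrative sketch only (the helper below is hypothetical, not part of this
# module): per the note above, a create_fn is any callable that accepts a
# concrete GuardBuilderBase-derived builder plus the Guard being installed.
#
#     def check_type_id(builder: GuardBuilderBase, guard: Guard) -> None:
#         # a concrete builder (e.g. dynamo's GuardBuilder) would emit code
#         # here that checks the type of the object named by guard.name
#         ...
#
#     guard = some_source.make_guard(check_type_id)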
class ShapeGuard(NamedTuple):
    expr: sympy.Expr
    stack: CapturedTraceback
@dataclasses.dataclass
class Guard:
    # originating_source is the source that called the make_guard method to
    # construct this guard object. The property name specifies what exactly it
    # is the guard is guarding on.  The meaning of the name is dependent on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on.  This evaluation
    # typically happens in GuardBuilder.eval.  In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the property source).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless.  Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    _hash: Optional[int] = None
    def __hash__(self):
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        return (
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()
    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround of a Python weakref bug.

        `obj_weakref` is an instance returned by `weakref.ref`, and
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g.:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raises KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)
    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output
    def create(self, builder: GuardBuilderBase):
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.error("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_nn_module(self):
        return self.source.is_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        if not self.guard_types:
            self.guard_types = list()

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            self.code_list = code_list
        else:
            self.code_list.extend(code_list)

        assert self.obj_weakref in (
            obj_weakref,
            None,
        ), "Guarded object must be identical, or None"
        self.obj_weakref = obj_weakref
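# Illustrative sketch of the name/source relationship documented above (the
# source class and name format below are hypothetical): a Guard is normally
# produced by Source.make_guard, and guard.name is whatever Python expression
# the originating source renders, which a concrete guard builder later
# evaluates against globals/locals to fetch the guarded object.
#
#     source = SomeLocalSource("x")            # hypothetical Source subclass
#     guard = source.make_guard(my_create_fn)  # Guard(source, my_create_fn)
#     guard.name                               # e.g. "L['x']" -- an expression
#     guard.source                             # e.g. GuardSource.LOCAL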
T = TypeVar("T")

"""
Parent structure for guard env expressions.
A GuardEnvExpr can have any subtype.
Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


class GuardEnvExpr:
    pass
""" | |
A class representing a pair of duplicate inputs. | |
input_pos_a and input_pos_b are input positions we have deduped. | |
""" | |
class DuplicateInputs(GuardEnvExpr): | |
input_source_a: Source | |
input_source_b: Source | |
def __post_init__(self): | |
assert self.input_source_a != self.input_source_b | |
""" | |
Checkpointable is an interface for driving state snapshotting, left purposely vague for now. | |
copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that | |
can also be taken in at restore_graphstate(T) calls. | |
When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable | |
does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet. | |
In the future, it will have a closer coupling to a generic Checkpoint management system. | |
""" | |
class Checkpointable(ABC, Generic[T]): | |
def copy_graphstate(self) -> T: | |
... | |
def restore_graphstate(self, state: T): | |
... | |
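# Minimal usage sketch of the Checkpointable protocol, assuming the
# GuardsContext defined later in this file: snapshot with copy_graphstate(),
# mutate, and roll back with restore_graphstate() if the attempt fails.
#
#     guards_ctx = GuardsContext()
#     snapshot = guards_ctx.copy_graphstate()   # -> GuardsCheckpointState
#     ...                                       # speculatively install guards
#     guards_ctx.restore_graphstate(snapshot)   # restore the earlier state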
class GuardsCheckpointState:
    """
    The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found, otherwise, returns a set() of mismatched
        Guard type objects.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None
class ModuleContextCheckpointState:
    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        module key names.
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    def __init__(self):
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        self.nn_modules = state.nn_modules
class GlobalContextCheckpointState:
    global_state: Dict[str, Tuple[Callable, ...]] = {}

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        global key names.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self):
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        assert isinstance(state, GlobalContextCheckpointState)
        self.global_state = state.global_state
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)
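# Sketch of the expected shape of a global_state entry (the exact pairing of
# setter and saved value shown here is illustrative; dynamo's
# save_global_state in output_graph.py is the source of truth for how the
# entries are populated):
#
#     ctx = GlobalContext()
#     ctx.global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
#     # ... one (func, arg) entry per key in _supported_global_states ...
#     state = ctx.copy_graphstate()
#     ctx.restore_graphstate(state)  # re-applies func(arg) for every entry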
""" | |
A GuardsContext is a checkpointable representation of all the guards in the current tracing | |
context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated | |
directly outside of it. For passing around internal state representations of this object, | |
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState. | |
""" | |
# Like a Set[Guard] but will record the user stack on all guards at the | |
# time they were installed at their destination | |
class GuardsSet: | |
def __init__(self, inner=None): | |
if inner is None: | |
inner = set() | |
self.inner = inner | |
def __iter__(self): | |
return iter(self.inner) | |
def __len__(self): | |
return len(self.inner) | |
# Subtraction along with bool is typically used to determine the delta of | |
# added guards between checkpoints for higher order ops | |
def __sub__(self, other): | |
return GuardsSet(self.inner - other.inner) | |
def __bool__(self): | |
return bool(self.inner) | |
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0): | |
if guard in self.inner: | |
return | |
if collect_debug_stack: | |
if guard.stack is None: | |
guard.stack = CapturedTraceback.extract(skip=1 + skip) | |
if guard.user_stack is None: | |
guard.user_stack = TracingContext.extract_stack() | |
self.inner.add(guard) | |
def update(self, *others: Set[Guard]): | |
for o in others: | |
for g in o: | |
self.add(g, skip=1) | |
def remove_guards_with_source(self, source): | |
"""Delete all guards with a given source""" | |
self.inner = {g for g in self.inner if g.originating_source != source} | |
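# Sketch of the delta pattern mentioned in the __sub__/__bool__ comment above:
# snapshot the inner set, trace some more, and subtract to see which guards
# were added in between.
#
#     before = GuardsSet(set(guards.inner))
#     ...                                    # tracing installs more guards
#     new_guards = guards - before
#     if new_guards:                         # __bool__: any guards added?
#         ...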
class GuardsContext(Checkpointable[GuardsCheckpointState]):
    def __init__(self):
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        return GuardsCheckpointState(set(self.dynamo_guards.inner))

    def restore_graphstate(self, state):
        # NB: "steals" the passed in state
        assert isinstance(state, GuardsCheckpointState)
        self.dynamo_guards = GuardsSet(state.dynamo_guards)


_TLS = threading.local()
""" | |
TracingContext is the source of truth for all currently accumulated information | |
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems | |
are open to managing their own TracingContext with that in mind. | |
The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid | |
having to plumb complex subsystems across multiple verticals. | |
Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor. | |
Accessing the current tracing context via | |
TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how | |
to plumb objects back up to where frame interpretation happened. | |
Note that you can end up with multiple TracingContext for a single compilation | |
of a frame, as we reset the TracingContext whenever we restart analysis. | |
CompileContext is a more overarching context that encompasses multiple restarts. | |
""" | |
class CompileContext:
    @staticmethod
    def get() -> CompileContext:
        assert _TLS.compile_context is not None
        return _TLS.compile_context

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)
class TracingContext:
    """
    Provides the currently installed TracingContext, or None.

    Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
    will return None.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )
    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity.  As we process a frame, this variable gets updated
        # to keep track of what line we are on in the function.  When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known.  This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm.  Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[List[int]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer.  This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()
        # If this is True, AOTAutograd will return output Fake Tensors with
        # appropriate meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False
    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}
    @staticmethod
    @contextmanager
    def patch(**kwargs):
        prior = {}
        ctx = TracingContext.get()

        for key in kwargs.keys():
            # KeyError on invalid entry
            prior[key] = getattr(ctx, key)
        for key, val in kwargs.items():
            setattr(ctx, key, val)
        try:
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)
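    # Usage sketch for patch (the attribute shown is an existing field set in
    # __init__; any TracingContext attribute name would work the same way):
    #
    #     with TracingContext.patch(fakify_first_call=True):
    #         ...  # callees observe the patched value; the prior value is restored on exit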
    @staticmethod
    def extract_stack():
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)
    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has a real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything.  Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # the user stack is set (using the context manager).  However,
                # our compiler stack does "tail calls" (when it calls into the
                # user compiler), at which point the parent exception frames
                # would attach an incorrect frame.
                #
                # However, if, somehow, someone raised an exception with this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it.  Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise
    @staticmethod
    @contextmanager
    def current_frame(frame_summary):
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old
    @staticmethod
    @contextmanager
    def report_output_strides():
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name
        )
@contextmanager
def compile_context(context: CompileContext):
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        _TLS.compile_context = old_context
@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed in tracing context as a dynamic scoped
    global variable.

    Calls to TracingContext.try_get() while not under a `with tracing()` context
    will return None.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        if (
            context is not None
            and context.fake_mode is not None
            and context.fake_mode.shape_env is not None
        ):
            context.fake_mode.shape_env.cleanup()
        _TLS.tracing_context = old_context
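# Usage sketch for the two context managers above (fake_mode is assumed to be
# a FakeTensorMode or None; guard is whatever Guard you want to install):
#
#     with compile_context(CompileContext(CompileId(0, 0))):
#         with tracing(TracingContext(fake_mode)):
#             TracingContext.get().guards_context.dynamo_guards.add(guard)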
# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
class Source:
    def is_dict_key(self):
        return False

    def is_ephemeral(self):
        return False

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self) -> GuardSource:
        raise NotImplementedError()

    def name(self) -> str:
        raise NotImplementedError()

    def make_guard(self, fn) -> Guard:
        if self.guard_source() is GuardSource.CONSTANT:
            raise NotImplementedError()
        return Guard(self, fn)

    def is_nn_module(self) -> bool:
        return self.guard_source().is_nn_module()

    def subguards_allowed(self):
        """True if you can guard on attributes of this"""
        return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
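# Sketch of what a concrete Source subclass can look like (hypothetical; the
# real subclasses live in torch/_dynamo/source.py as noted above):
#
#     @dataclasses.dataclass(frozen=True)
#     class MyLocalSource(Source):
#         local_name: str
#
#         def name(self) -> str:
#             return f"L[{self.local_name!r}]"
#
#         def guard_source(self) -> GuardSource:
#             return GuardSource.LOCAL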
# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

    def is_dict_key(self):
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self):
        return self.base.is_ephemeral()
def detect_fake_mode(inputs: Any = None):
    """
    Attempts to "detect" what the current fake mode is.  If there is one ambiently
    available from TracingContext, we preferentially use that.  Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

        - Currently active fake mode on stack
        - Fake mode associated with passed in tensors (inputs does not
          have to be flattened)
    """
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    fake_modes = []

    if context := TracingContext.try_get():
        fake_mode = context.fake_mode
        if fake_mode is not None:
            fake_modes.append((fake_mode, "tracing context", 0))

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            fake_modes.append((m, "active fake mode", i))

    flat_inputs = pytree.tree_leaves(inputs)
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

    if fake_modes:
        fake_mode, desc1, i1 = fake_modes[0]
        for m, desc2, i2 in fake_modes[1:]:
            assert fake_mode is m, (
                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
            )
        return fake_mode
    else:
        return None
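# Usage sketch (assuming no ambient TracingContext or active fake mode is
# installed): all detected sources must agree on a single FakeTensorMode,
# otherwise the assert above fires.
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     fake_mode = FakeTensorMode()
#     x = fake_mode.from_tensor(torch.randn(3))
#     assert detect_fake_mode([x]) is fake_mode
#     assert detect_fake_mode() is None   # nothing to detect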
def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for m in reversed(_get_current_dispatch_mode_stack()):
        if isinstance(m, FakeTensorMode):
            return m

    return None