Spaces:
Running
Running
# mypy: ignore-errors
import collections | |
from enum import Enum | |
from typing import Any, Callable, Dict, List | |
from .. import variables | |
from ..current_scope_id import current_scope_id | |
from ..exc import unimplemented | |
from ..source import AttrSource, Source | |
from ..utils import identity, istype | |
class MutableLocalSource(Enum):
    """
    Origin of the variable that a VariableTracker.mutable_local describes.

    Existing: the variable already existed when Dynamo began tracking it
        during introspection.
    Local: the variable is new and was created during Dynamo introspection.
    """

    Existing = 0
    Local = 1
class ParentsTracker:
    """
    Perf optimization that limits how many objects tx.replace_all has to visit.

    It lives in its own object (separate from the VariableTracker) so that it
    is not cloned in apply.
    """

    def __init__(self):
        # Conceptually a set; a dict is used so iteration order is
        # deterministic (insertion order).
        self.parents: Dict[ParentsTracker, bool] = {}

    def add(self, parent):
        """Record `parent` as a direct parent of this tracker."""
        self.parents[parent] = True

    def recursive_parents(self):
        """
        Return a keys view of every tracker reachable through `parents`
        links, starting from this tracker's direct parents.

        Each tracker appears once, in discovery order; cycles are handled
        because already-seen trackers are never re-queued.
        """
        seen = dict(self.parents)
        pending = list(self.parents)
        while pending:
            current = pending.pop()
            for ancestor in current.parents:
                if ancestor in seen:
                    continue
                assert isinstance(ancestor, ParentsTracker)
                seen[ancestor] = True
                pending.append(ancestor)
        return seen.keys()
class MutableLocalBase:
    """
    Base class for Variable.mutable_local
    """

    def __init__(self, typ: MutableLocalSource):
        # While tracing a HigherOrderOperator we have to tell apart
        # MutableLocals created inside the HigherOrderOperator from those
        # created outside of it. In the example below it is not safe to
        # mutate `a`, because it was constructed in a different scope.
        #
        # def f(x):
        #     a = 1
        #     def g(x):
        #         nonlocal a
        #         a = 2
        #         return x
        #     return wrap(g, x) + a
        #
        # self.scope encodes that distinction:
        # scope == 0: The object was an existing variable
        # scope == 1: The object was created while Dynamo
        #             was introspecting a function
        #             (and no HigherOrderOps were involved)
        # scope >= 2: The object was created through
        #             Dynamo introspection of a HigherOrderOp.
        #             The exact number corresponds to the level
        #             of nested HigherOrderOps.
        if typ is MutableLocalSource.Local:
            self.scope = current_scope_id()
            return
        if typ is MutableLocalSource.Existing:
            self.scope = 0
            return
        unimplemented(f"Unsupported MutableLocalSource: {typ}")
class MutableLocal(MutableLocalBase):
    """
    Marker meaning the tracked object (list, iter, etc) was constructed in
    local scope, so analysis may mutate it safely without leaking state.
    """

    def __init__(self):
        # Always registered as a locally created variable.
        super().__init__(MutableLocalSource.Local)

    # Equality and hashing are by identity: every marker is distinct.
    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)
def _is_top_level_scope(scope_id): | |
return scope_id == 1 | |
def is_side_effect_safe(m: MutableLocalBase):
    """
    Decide whether the variable described by `m` may be mutated in the
    current scope.

    In the top-level scope (no HigherOrderOperators involved) both variables
    created in that scope and pre-existing variables may be modified.
    Anywhere else, only variables created in the current scope may be.
    """
    active_scope = current_scope_id()
    return _is_top_level_scope(active_scope) or m.scope == active_scope
class VariableTrackerMeta(type):
    """Metaclass that adds a __post_init__ hook and lazy-aware isinstance."""

    def __call__(cls, *args, **kwargs):
        """Call __post_init__"""
        # Build the instance the normal way, then hand the same constructor
        # arguments to the post-init hook.
        created = type.__call__(cls, *args, **kwargs)
        created.__post_init__(*args, **kwargs)
        return created

    def __instancecheck__(cls, instance) -> bool:
        """Make isinstance work with LazyVariableTracker"""
        is_lazy = type.__instancecheck__(variables.LazyVariableTracker, instance)
        if is_lazy and cls not in (VariableTracker, variables.LazyVariableTracker):
            # A lazy instance checked against a specific subclass is realized
            # first so the check runs against the realized object.
            instance = instance.realize()
        return type.__instancecheck__(cls, instance)
class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them.
    """

    # fields to leave unmodified in apply()
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutable_local",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes"""
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def copy(cls, value):
        """Deeper (but not full) copy, leaving FX and user objects alone"""
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls,
        fn: Callable[["VariableTracker"], "VariableTracker"],
        value,
        cache=None,
        skip_fn=lambda _: False,  # Whether we should skip applying to this var
    ):
        """
        Walk this object and call fn on all the VariableTracker
        instances

        Args:
            fn: callable invoked on each VariableTracker that is visited.
            value: the object to walk. Lists, tuples and dicts are rebuilt
                with their elements walked recursively; any other
                non-VariableTracker object is returned unchanged.
            cache: maps id(object) -> (result, object) so that an object
                reached more than once is processed only once. A fresh
                dict is created when None is passed.
            skip_fn: predicate; when it returns True for a VariableTracker,
                fn is still called on it but its fields are not walked.

        Returns:
            The result of walking `value`.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        if isinstance(value, VariableTracker):
            if not skip_fn(value):

                def update_object_dict(v):
                    # Walk every attribute of v that is not listed in
                    # _nonvar_fields, writing the result back in place.
                    rv = v.__dict__
                    for key in rv.keys():
                        if key not in v._nonvar_fields:
                            rv[key] = cls.apply(fn, rv[key], cache, skip_fn)
                    return v

                value = value.unwrap()
                was_realized = value.is_realized()
                result = fn(update_object_dict(value))
                if not was_realized and value.is_realized():
                    # running fn() resulted in value getting realized,
                    # which means we missed updating the contents of result
                    result = update_object_dict(result.unwrap())
            else:
                result = fn(value)
                if result is not None:
                    result = result.unwrap()
        elif istype(value, list):
            result = [cls.apply(fn, v, cache, skip_fn) for v in value]
        elif istype(value, tuple):
            result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
        elif istype(value, (dict, collections.OrderedDict)):
            # Snapshot the items first so the iteration is unaffected if the
            # dict is modified while its values are being walked.
            result = {
                k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
            }
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def python_type(self):
        """
        Abstract method to be implemented by subclasses of VariableTracker.

        This method should return the type represented by the instance of the subclass.
        The purpose is to provide a standardized way to retrieve the Python type information
        of the variable being tracked.

        Returns:
            type: The Python type (such as int, str, list, etc.) of the variable tracked by
                the subclass. If the type cannot be determined or is not relevant,
                leaving it undefined or invoking super() is always sound.

        Note:
            This is an abstract method and may be overridden in subclasses.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: If the method is not implemented in a subclass.
        """
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def guard_as_python_constant(self):
        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
        try:
            return self.as_python_constant()
        except NotImplementedError as e:
            unimplemented(str(e))

    def is_python_constant(self):
        """Return True if as_python_constant() succeeds, False if it raises NotImplementedError."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """Build a guard through self.source; raise NotImplementedError when there is no source."""
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError()

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable"""
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError()
        source = None
        if self.source:
            source = AttrSource(self.source, name)
        return variables.ConstantVariable.create(value, source=source)

    def is_proxy(self):
        """Return True if as_proxy() succeeds, False if it raises NotImplementedError."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        """Return the proxy for this variable; the base class has none and raises NotImplementedError."""
        raise NotImplementedError(str(self))

    def maybe_fx_node(self):
        """Return the node of this variable's proxy if as_proxy() yields a torch.fx.Proxy, else None."""
        try:
            proxy = self.as_proxy()
            import torch.fx

            if isinstance(proxy, torch.fx.Proxy):
                return proxy.node
            return None
        except NotImplementedError:
            return None

    def reconstruct(self, codegen):
        """Reconstruct this variable using `codegen`; unsupported in the base class."""
        raise NotImplementedError()

    def can_reconstruct(self, tx):
        """If it is possible to reconstruct the Python object this
        VariableTracker represents."""
        assert tx is tx.output.root_tx, "Only root tx can reconstruct"
        try:
            from ..codegen import PyCodegen

            cg = PyCodegen(tx)
            self.reconstruct(cg)
            return True
        except NotImplementedError:
            return False

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return this variable's elements as a list of VariableTrackers; unsupported in the base class."""
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx) -> bool:
        """Return True if unpack_var_sequence(tx) succeeds, False if it raises NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    def inspect_parameter_names(self) -> List[str]:
        unimplemented(f"inspect_parameter_names: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        unimplemented(f"hasattr {self.__class__.__name__} {name}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Handle a method call on this variable.

        The base class supports `__len__` (when the variable can be unpacked
        into a sequence) and `__getattr__` with a single constant argument;
        everything else is reported through unimplemented().
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant())
        # NOTE(review): unimplemented() presumably raises on its own; the
        # explicit `raise` is kept as in the original - confirm.
        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def rename(self, tx, name):
        return self

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def recursive_realize(self):
        """Realize all objects under this"""
        return VariableTracker.apply(lambda x: x.realize(), self)

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self):
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def __init__(
        self,
        *,
        source: Source = None,
        mutable_local: MutableLocal = None,
        parents_tracker: ParentsTracker = None,
    ):
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local
        self.parents_tracker = parents_tracker

    def __post_init__(self, *args, **kwargs):
        """Invoked by VariableTrackerMeta.__call__ right after __init__, with the same arguments."""
        if self.parents_tracker is None:
            self.parents_tracker = ParentsTracker()
        # visit children 1 level deep and ensure parent is set properly
        VariableTracker.apply(
            lambda node: node.parents_tracker.add(self.parents_tracker),
            [v for k, v in self.__dict__.items() if k not in self._nonvar_fields],
            skip_fn=lambda _: True,
        )
def typestr(*objs):
    """
    Describe the given objects as a string.

    For a single object: str(obj) when it is a VariableTracker, otherwise
    the name of its type. For zero or several objects: the space-joined
    descriptions of each one.
    """
    if len(objs) != 1:
        return " ".join(map(typestr, objs))

    (single,) = objs
    if isinstance(single, VariableTracker):
        return str(single)
    return type(single).__name__