# mypy: ignore-errors

import collections
import dataclasses
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List

import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree

from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    identity,
    is_tensor_base_attr_getter,
    proxy_args_kwargs,
)
from .base import VariableTracker
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import UserDefinedObjectVariable


class SuperVariable(VariableTracker):
    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        # typevar is the first argument to super(). In the case where no argument
        # is provided to super(), it is the __class__ object where
        # the super() function is being called
        self.typevar = typevar
        # objvar here must be an instance or subtype of typevar.
        # In the case where super() is called without arguments, it is the first argument
        # to the current function where super() is called from (self for regular method,
        # cls for a classmethod)
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true
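        # Illustrative sketch (hypothetical user code being traced), to make the
        # two fields above concrete:
        #
        #   class B(A):
        #       def forward(self, x):
        #           return super().forward(x)  # zero-argument super()
        #
        # While tracing B.forward, typevar tracks B (recovered from __class__)
        # and objvar tracks self.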

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, True))
        else:
            codegen.extend_output(create_call_function(1, True))

    def _resolved_getattr_and_source(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name)
        search_type = self.typevar.as_python_constant()

        # The rest of this function does two things:
        # - Walk the mro to find where the attribute comes from to be
        #   able to provide accurate source
        # - Call the getattr to get the object

        # Find the class object, where the function lives.
        # When objvar is "self", use type(self), when objvar is "cls", use it as-is
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        if self.objvar.source is not None:
            # Walk the mro tuple to find out the actual class where the
            # attribute resides.
            search_mro = type_to_use.__mro__
            start_index = search_mro.index(search_type) + 1
            for index in range(start_index, len(search_mro)):
                if hasattr(search_mro[index], name):
                    # Equivalent of something like type(L['self']).__mro__[1].attr_name
                    source = AttrSource(
                        GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
                        name,
                    )
                    break

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name), source

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Mostly super is called with a
        # method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(self, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
            return variables.ConstantVariable.create(value, source=source)
        return variables.ConstantVariable.create(value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        inner_fn, source = self._resolved_getattr_and_source(self, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                tx.output.side_effects.store_attr(
                    objvar,
                    "__call_nn_module_init",
                    variables.ConstantVariable.create(True),
                )
                return variables.ConstantVariable.create(None)
            else:
                unimplemented("super() nn.Module.__init__")
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        elif inner_fn in (
            collections.OrderedDict.__setitem__,
            object.__setattr__,
        ) and isinstance(self.objvar, variables.CustomizedDictVariable):
            assert not kwargs and len(args) == 2
            return super(variables.CustomizedDictVariable, self.objvar).call_method(
                tx, "__setitem__", args, kwargs
            )
        else:
            unimplemented(f"non-function or method super: {inner_fn}")


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time.
    """

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        assert len(args) == 1
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
        return variables.ConstantVariable.create(None)


class ClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        codegen.append_output(codegen.create_load_closure(self.name))


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        if name == "parameters":
            return variables.ConstDictVariable(
                {
                    variables.ConstantVariable.create(name): InspectParameterVariable()
                    for name in self.inspected.inspect_parameter_names()
                },
                user_cls=dict,
            )
        return super().var_getattr(tx, name)


class InspectParameterVariable(VariableTracker):
    """This is not implemented; if used, it will graph break."""

    pass
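

# produce_trampoline_autograd_apply gives Dynamo a stable, graph-representable
# callable for an autograd.Function's apply(). When trace_rules marks the
# Function class as an allowed callable, AutogradFunctionVariable.call_method
# below inserts a call_function proxy to this trampoline into the FX graph
# instead of tracing into the Function itself.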
def produce_trampoline_autograd_apply(fn_cls):
    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        if (
            requires_grad
            and torch.is_grad_enabled()
            and config.capture_autograd_function
        ):
            # Note - this is the same check used in autograd/function.py, except inverted.
            # If we want to support functorch transforms here, we will need to enable this.
            if (
                self.fn_cls.setup_context
                != torch.autograd.function._SingleLevelFunction.setup_context
            ):
                unimplemented(
                    "NYI - autograd.Function with custom setup_context method"
                )

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defined vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defined jvp")

            from .higher_order_ops import AutogradFunctionApplyVariable

            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            return AutogradFunctionApplyVariable(
                self.fn_cls.forward,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)

        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        from ..trace_rules import is_callable_allowed
        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)
        else:
            unimplemented(f"Unsupported method: {name}")


@dataclasses.dataclass
class SavedTensorBox:
    tensors: List[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        **kwargs,
    ):
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors

    @staticmethod
    def create(tx):
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
            ),
            {},
        )
        proxy.node.meta["example_value"] = out.value
        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        if name == "save_for_backward":
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        return super().var_getattr(tx, name)


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        codegen.extend_output(codegen.create_load_attrs(self.name))

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
    def __init__(self, method_wrapper, **kwargs):
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
            args[0], variables.TensorVariable
        ):
            assert len(args) == 1 and len(kwargs) == 0
            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    def __init__(self, desc, **kwargs):
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx, name):
        if name == "__get__" and self.source:
            from .builder import VariableBuilder

            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
                self.desc.__get__
            )
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self):
        return f"PythonModuleVariable({self.value})"

    def call_hasattr(self, tx, name):
        if self.is_torch:
            result = hasattr(self.value, name)
            return variables.ConstantVariable.create(result)
        return super().call_hasattr(tx, name)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable.create(
                self.value[args[0].as_python_constant()],
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


def get_np_to_tnp_map():
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}
    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # some internal details do leak from tnp
                # which are not part of numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn
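
# Illustrative sketch of the result (hypothetical entries, assuming
# NP_TO_TNP_MODULE pairs numpy modules with their torch._numpy counterparts):
# the returned dict maps NumPy callables to tensor-backed equivalents, roughly
#   {np.sum: tnp.sum, np.abs: tnp.abs, ...}
# which NumpyVariable.call_function below uses to swap traced numpy calls for
# their torch._numpy versions.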


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, it is able to trace a small subset of numpy
    functions as well as numpy dtypes.
    """

    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        # lookup method name in tnp. Things like np.dtype(float) are not supported yet.
        if self.value.__name__ == "dtype":
            unimplemented(
                f"numpy dtype function is not supported yet. Got type {type(self.value)}."
            )
        else:  # We are dealing with a callable.
            func = get_np_to_tnp_map().get(self.value)
            if func is None:
                unimplemented(
                    f"Can't find numpy function {self.value} in torch._numpy. "
                    "Please file an issue to request support for this function."
                )

            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            constant_args = check_constant_args(args, kwargs)
            unspec_python_args = check_unspec_python_args(args, kwargs)
            if self.can_constant_fold_through(func) and (
                constant_args or unspec_python_args
            ):
                # constant fold
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO Add all the functions that go from constants to constants to can_constant_fold_through
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32
            # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
            return self.value.__name__

        return super().as_proxy()


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        codegen.append_output(create_instruction("PUSH_NULL"))


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""


class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(); the actual formatting is delayed until
    after the graph runs.
    """

    _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}

    @classmethod
    def create(cls, format_string, sym_args, sym_kwargs):
        if all(
            x.is_python_constant()
            for x in itertools.chain(sym_args, sym_kwargs.values())
        ):
            return variables.ConstantVariable.create(
                format_string.format(
                    *[v.as_python_constant() for v in sym_args],
                    **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
                )
            )
        return cls(format_string, list(sym_args), dict(sym_kwargs))

    def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(format_string, str)
        self.format_string = format_string
        self.sym_args = sym_args
        self.sym_kwargs = sym_kwargs

    def __repr__(self):
        return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

    def reconstruct(self, codegen):
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
        codegen.append_output(codegen.create_load_const(self.format_string))
        codegen.append_output(codegen.create_load_attr("format"))
        codegen(variables.TupleVariable(self.sym_args))
        kwargs = {
            variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
        }
        codegen(variables.ConstDictVariable(kwargs))
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
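        # The emitted bytecode is roughly equivalent to evaluating
        #   "<format_string>".format(*sym_args, **sym_kwargs)
        # at runtime, after the compiled graph has produced the argument values.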


class DebuggingVariable(VariableTracker):
    """
    Represents a call to a debugging function like print(), or something
    registered to config.reorderable_logging_functions.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_reorderable_logging_function(obj):
        return (
            callable(obj)
            and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
            and obj in torch._dynamo.config.reorderable_logging_functions
        )

    def call_function(self, tx, args, kwargs):
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return

        if not self.can_reorder_logs(self.value, args, kwargs):
            unimplemented(
                f"Reordering debugging function {self.value} "
                f"with inputs {args} {kwargs} is not yet implemented."
            )

        tx.debug_locals.append((self, list(args)))

    def reconstruct(self, codegen):
        return self.source.reconstruct(codegen)

    @staticmethod
    def can_reorder_logs(fn, args, kwargs) -> bool:
        """
        Run some additional checks for what sort of function calls we can
        actually reorder.
        """
        allowed_input_types = (
            variables.TensorVariable,
            variables.ConstantVariable,
            StringFormatVariable,
        )

        flat_args = pytree.tree_leaves([args, kwargs])
        for arg in flat_args:
            if not isinstance(arg, allowed_input_types):
                return False
        return True