import inspect
from typing import Any, Dict, List, Optional, Union

import torch.nn

from . import utils, variables
from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import LocalSource, Source
from .utils import nn_module_new, object_new
from .variables.base import (
    is_side_effect_safe,
    MutableLocalBase,
    MutableLocalSource,
    VariableTracker,
)


class MutableSideEffects(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to indicate a list passed as
    an input; if we mutate it, we need to re-apply those mutations
    after the graph runs.
    """
    def __init__(self, source: Source, is_modified: bool = False):
        super().__init__(MutableLocalSource.Existing)
        self.source = source
        self.is_modified = is_modified


class AttributeMutation(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to track changes to attributes
    """

    def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
        super().__init__(typ)
        self.source = source


class AttributeMutationExisting(AttributeMutation):
    def __init__(self, source: Source):
        super().__init__(MutableLocalSource.Existing, source)
        self.source = source


class AttributeMutationNew(AttributeMutation):
    def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
        super().__init__(MutableLocalSource.Local, source)
        self.cls_source = cls_source


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc.) that need to be
    applied after an FX graph is run.
    """

    id_to_variable: Dict[int, VariableTracker]
    store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
    keepalive: List[Any]

    def __init__(
        self,
        id_to_variable=None,
        store_attr_mutations=None,
        keepalive=None,
        save_for_backward=None,
        tensor_hooks=None,
    ):
        super().__init__()
        self.id_to_variable = id_to_variable or {}
        self.store_attr_mutations = store_attr_mutations or {}
        self.keepalive = keepalive or []
        self.save_for_backward = save_for_backward or []
        self.tensor_hooks = tensor_hooks or {}

    def __eq__(self, other: object) -> bool:
        assert isinstance(other, SideEffects)
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
            and self.save_for_backward == other.save_for_backward
            and self.tensor_hooks == other.tensor_hooks
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        if self.id_to_variable != other.id_to_variable:
            sk_itv = self.id_to_variable.keys()
            ok_itv = other.id_to_variable.keys()
            if sk_itv != ok_itv:
                return f"id_to_variable keys: {sk_itv} != {ok_itv}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        elif self.store_attr_mutations != other.store_attr_mutations:
            sk_sam = self.store_attr_mutations.keys()
            ok_sam = other.store_attr_mutations.keys()
            if sk_sam != ok_sam:
                return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
            return "store_attr_mutations: unknown diff"
        elif self.save_for_backward != other.save_for_backward:
            return "save_for_backward"
        elif self.tensor_hooks != other.tensor_hooks:
            return "tensor_hooks"
        else:
            return None

    def clone(self):
        """Create a shallow copy"""
        return self.__class__(
            id_to_variable=dict(self.id_to_variable),
            store_attr_mutations={
                k: dict(v) for k, v in self.store_attr_mutations.items()
            },
            keepalive=list(self.keepalive),
            save_for_backward=self.save_for_backward,
            tensor_hooks=self.tensor_hooks,
        )

    def apply(self, fn, cache=None, skip_fn=lambda _: False):
        if cache is None:
            cache = dict()
        self.id_to_variable = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.id_to_variable.items()
        }
        self.store_attr_mutations = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.store_attr_mutations.items()
        }
        self.save_for_backward = VariableTracker.apply(
            fn, self.save_for_backward, cache, skip_fn
        )
        self.tensor_hooks = VariableTracker.apply(fn, self.tensor_hooks, cache, skip_fn)

    def __contains__(self, item):
        return id(item) in self.id_to_variable

    def __getitem__(self, item):
        return self.id_to_variable[id(item)]

    def check_allowed_side_effect(self, item):
        from torch._dynamo.variables.misc import AutogradFunctionContextVariable

        # People do things like self.dim = dim inside autograd.Function.
        # These are benign.
        if isinstance(item, AutogradFunctionContextVariable):
            return True
        if not is_side_effect_safe(item.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
            )
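
    # Hypothetical example of the unsupported case above: a higher-order op
    # subgraph mutating state captured from an enclosing scope, e.g.
    #
    #     counts = []
    #     def branch(x):
    #         counts.append(1)  # mutates a variable outside the subgraph
    #         return x.cos()
    #
    # is_side_effect_safe() rejects this, and we bail out via unimplemented().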

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
        assert self.is_attribute_mutation(item)
        self.check_allowed_side_effect(item)
        if item.mutable_local not in self.store_attr_mutations:
            self.store_attr_mutations[item.mutable_local] = {}
        self.store_attr_mutations[item.mutable_local][name] = value

    def load_attr(self, item, name, deleted_ok=False):
        assert self.is_attribute_mutation(item)
        result = self.store_attr_mutations[item.mutable_local][name]
        if not deleted_ok and isinstance(result, variables.DeletedVariable):
            unimplemented("read deleted attribute")
        return result

    def store_cell(self, cellvar, value):
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")
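
    # Cells back closures created inside the compiled region. A minimal
    # sketch of code that produces a NewCellVariable:
    #
    #     def fn(x):
    #         y = x + 1
    #         def inner():
    #             return y  # `y` lives in a cell shared with inner()
    #         return inner
    #
    # Reads and writes of `y` are modeled as loads/stores of the cell's
    # "cell_contents" attribute via the two methods above.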

    def load_global(self, gvar: VariableTracker, name: str):
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
        return inspect.getattr_static(cls, "__setattr__", None) in (
            object.__setattr__,
            torch.nn.Module.__setattr__,
        )
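
    # Only the default object/nn.Module __setattr__ can be replayed
    # faithfully. A class with a custom __setattr__, e.g.
    #
    #     class Logged:
    #         def __setattr__(self, name, value):
    #             print("set", name)
    #             super().__setattr__(name, value)
    #
    # runs arbitrary code on every attribute store, so we do not claim
    # to support mutation side effects on it.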

    def is_attribute_mutation(self, item):
        return isinstance(item.mutable_local, AttributeMutation)

    def has_pending_mutation(self, item):
        return self.is_attribute_mutation(item) and bool(
            self.store_attr_mutations.get(item.mutable_local)
        )

    def is_modified(self, item):
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return item.mutable_local in self.store_attr_mutations
        return item.mutable_local.is_modified

    def _track_obj(
        self,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """Start tracking a new variable for mutation"""
        assert variable.source is not None
        variable.mutable_local = mutable_cls(variable.source)
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    track_mutable = _track_obj

    def track_object_existing(
        self,
        item: Any,
        variable: VariableTracker,
    ):
        return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
        if user_cls is torch.autograd.function.FunctionCtx:
            obj = torch.autograd.Function()
        elif issubclass(user_cls, torch.nn.Module):
            obj = nn_module_new(user_cls)
        else:
            obj = object_new(user_cls)
        variable = variable_cls(
            obj,
            mutable_local=AttributeMutationNew(None, cls_source),
            **options,
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable
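
    # Illustrative trigger (hypothetical user code): constructing a fresh
    # object inside the compiled region, e.g.
    #
    #     class Point:
    #         pass
    #
    #     def fn(x):
    #         p = Point()    # -> track_object_new(...)
    #         p.val = x + 1  # -> store_attr(...)
    #         return p
    #
    # The reconstructed bytecode re-creates `p` via object_new and replays
    # the attribute store.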

    def track_cell_new(self):
        obj = object()
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_existing(self, source: Source, item: Any):
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_global_existing(self, source: Source, item: Any):
        variable = variables.NewGlobalVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_save_for_backward(self, ctx, args):
        assert isinstance(ctx, variables.AutogradFunctionContextVariable)
        self.save_for_backward.append((ctx, args))
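
    # Hypothetical autograd.Function body that lands here:
    #
    #     class MyFn(torch.autograd.Function):
    #         @staticmethod
    #         def forward(ctx, x):
    #             ctx.save_for_backward(x)  # recorded, replayed in residuals
    #             return x * 2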

    def track_tensor_variables_from_runahead_side_effects(self, other):
        # In higher order ops we want to keep track of tensors seen in the
        # speculate_subgraph so that we don't lift them again as a new input in
        # other speculate_subgraph or in the root tracer.
        for other_item in other.keepalive:
            other_id = id(other_item)
            other_variable = other.id_to_variable[other_id]
            if other_id not in self.id_to_variable and isinstance(
                other_variable, variables.TensorVariable
            ):
                self.track_object_existing(other_item, other_variable)

    def prune_dead_object_new(self, tx):
        live_new_objects = set()
        skip_obj = None

        def visit(var: VariableTracker):
            if (
                isinstance(var.mutable_local, AttributeMutationNew)
                and var.mutable_local is not skip_obj
            ):
                live_new_objects.add(var.mutable_local)
            return var

        def is_live(var: Union[MutableLocalBase, VariableTracker]):
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
        for var in self.id_to_variable.values():
            if not isinstance(var.mutable_local, AttributeMutationNew):
                VariableTracker.apply(visit, var)

        # An attribute store on an object should not, by itself, keep that
        # object alive: setting skip_obj makes visit() ignore the object
        # whose own setattrs are being traversed.
        for skip_obj, setattrs in self.store_attr_mutations.items():
            VariableTracker.apply(visit, setattrs)

        self.id_to_variable = {
            k: v for k, v in self.id_to_variable.items() if is_live(v)
        }
        self.store_attr_mutations = {
            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
        }
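
    # Example of a prunable object (sketch, using a hypothetical Point
    # class): an instance created during tracing that is dead by the end
    # of the frame, e.g.
    #
    #     def fn(x):
    #         tmp = Point()  # AttributeMutationNew
    #         tmp.val = x    # store_attr on tmp
    #         return x + 1   # tmp is never returned or stored anywhere
    #
    # Nothing reachable from the stack or locals references tmp, so its
    # entries are dropped and no reconstruction bytecode is emitted for it.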

    def mutation(self, var):
        self.check_allowed_side_effect(var)
        if isinstance(var.mutable_local, MutableSideEffects):
            var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

    def _get_modified_vars(self):
        return [var for var in self.id_to_variable.values() if self.is_modified(var)]

    def codegen_save_tempvars(self, cg: PyCodegen):
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                cg.load_import_from(utils.__name__, "make_cell")
                cg.extend_output(create_call_function(0, True))
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif isinstance(var.mutable_local, AttributeMutationNew):
                if isinstance(var, variables.AutogradFunctionContextVariable):
                    unimplemented("AutogradFunctionContextVariable escaped")
                if "__call_nn_module_init" in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ):
                    assert isinstance(var, variables.UnspecializedNNModuleVariable)
                    cg.load_import_from(utils.__name__, "nn_module_new")
                else:
                    cg.load_import_from(utils.__name__, "object_new")
                cg(var.mutable_local.cls_source)
                cg.extend_output(create_call_function(1, True))
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.extend_output(
                [create_instruction("LOAD_METHOD", argval="save_for_backward")]
            )
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )

    def register_hook(self, tensor, hook, handle, name):
        assert isinstance(tensor, variables.TensorVariable)
        assert isinstance(hook, variables.VariableTracker)
        assert (
            isinstance(handle, variables.RemovableHandleVariable)
            and handle.mutable_local
        )
        assert hasattr(torch.Tensor, name)
        idx = len(self.tensor_hooks.keys())
        # duplicate index possible because of self.remove_hook()
        while idx in self.tensor_hooks:
            idx += 1
        self.tensor_hooks[idx] = (tensor, hook, handle, name)
        assert not handle.idx
        handle.idx = idx

    def remove_hook(self, idx):
        del self.tensor_hooks[idx]

    def codegen_hooks(self, cg):
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook on a tensor (AKA a backward hook) is implemented slightly differently
            # for hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals strictly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks: global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the handle: when a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo tensor hooks workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            #   - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"
            cg(tensor)
            cg.extend_output([cg.create_load_attr(name)])
            cg(hook)
            cg.extend_output(create_call_function(1, True))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook(). This consumes the top of stack.
            cg.add_cache(handle)
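
    # Hedged sketch of the user code this codegen path serves: registering
    # a hook on a graph input and keeping the handle, e.g.
    #
    #     def fn(x):
    #         handle = x.register_hook(some_global_fn)
    #         return x * 2, handle
    #
    # The generated residual bytecode re-issues x.register_hook(...) and
    # binds the result to the user's handle variable.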

    def codegen_update_mutated(self, cg: PyCodegen):
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.ConstDictVariable):
                cg.tx.output.update_co_names("clear")
                cg.tx.output.update_co_names("update")

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
            elif self.is_attribute_mutation(var):
                for name, value in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ).items():
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif name == "__call_nn_module_init":
                        pass  # handled in codegen_save_tempvars
                    elif isinstance(value, variables.DeletedVariable):
                        if isinstance(
                            var.mutable_local, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.mutable_local.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    else:
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var.mutable_local.source)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
            elif isinstance(var, variables.TupleIteratorVariable):
                for _ in range(var.index):
                    cg.load_import_from(utils.__name__, "iter_next")
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.extend_output(create_call_function(1, True))
                    cg.append_output(create_instruction("POP_TOP"))
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)
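
    # Rough Python equivalent of the residual bytecode emitted above for
    # the two container cases (illustrative, not the literal codegen):
    #
    #     old_list[:] = new_list_value  # ListVariable
    #
    #     tmp = new_dict_value          # ConstDictVariable
    #     old_dict.clear()
    #     old_dict.update(tmp)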

    def is_empty(self):
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.save_for_backward
            or self.tensor_hooks
        )

    def clear(self):
        self.keepalive.clear()
        self.id_to_variable.clear()