# mypy: ignore-errors

import functools
import inspect
from typing import Dict, List

import torch

from ...fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


class DistributedVariable(VariableTracker):
    """
    The base distributed variable that encapsulates common methods
    for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.).
    Concrete distributed objects could inherit this class and add object
    specific logic.

    That is, it provides the check on the distributed package existence
    and holds the tracking value for the corresponding distributed object.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        return type(self.value)

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed._tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local


def is_constant_pg_functions(value):
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions
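

# Illustrative note (assumption, for clarity only): "constant" here means the
# result depends only on the process group itself, so Dynamo can evaluate the
# call at trace time instead of putting it into the graph. For example:
#
#   from torch.distributed.distributed_c10d import _get_group_size_by_name
#   is_constant_pg_functions(_get_group_size_by_name)  # True
#   # so a call like `_get_group_size_by_name(pg.group_name)` can be baked in
#   # as a plain Python int for the traced (and guarded) process group.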


class PlacementClassVariable(DistributedVariable):
    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as
            # placements are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
                return var

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        # Dynamo tracking of Placement types only allows the following methods;
        # __setattr__ is included for cases like `Shard(dim)` that assign
        # attributes in __init__ (sketched below). Methods in the list must satisfy:
        #   1. their input arguments are constants and do not need to be guarded on;
        #   2. their output is constant with respect to their inputs.
        constant_fold_functions = [
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        ]
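
        # Illustrative sketch (not authoritative): during tracing,
        #   Shard(0).is_shard()  ->  folded to ConstantVariable.create(True)
        #   Shard(0).dim = 1     ->  routed through the "__setattr__" branch
        #                            below, mutating the underlying Placement.
        # The inputs must already be Python constants for this folding to be
        # sound, which is why no guards are installed on them.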

        if name in constant_fold_functions:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

        return super().call_method(tx, name, args, kwargs)


class DeviceMeshVariable(DistributedVariable):
    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            return ConstantVariable.create(self.value.get_group())
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)


class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), as well as passed to utility functions in distributed_c10d
    which desugar it into plain types like a ranklist and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
      like _expand_group without dynamo complaining about making a proxy for it.
      It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
      torch library functions are dealing with tensor-like types and would have
      proxies for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the
      default behaviors or just graph-break whenever one of our special cases is
      not hit?
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))


def get_global_pg_variable():
    """
    Make a ProcessGroupVariable from torch.distributed.group.WORLD and
    install guards.
    """
    import torch.distributed as dist

    source = AttrSource(
        AttrSource(
            base=AttrSource(
                base=GlobalSource(global_name="torch"),
                member="distributed",
                get_static=False,
            ),
            member="group",
            get_static=False,
        ),
        member="WORLD",
        get_static=False,
    )
    install_guard(source.make_guard(GuardBuilder.ID_MATCH))
    return ProcessGroupVariable(
        dist.group.WORLD,
        source=source,
    )


class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ):
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (which
            don't survive AOTAutograd), we install hooks that call
            trace_wrapped in the backward pass, which CompiledAutograd
            can turn into actual hook calls.
            """
            return torch.utils.hooks.BackwardHook(
                None,
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module)
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ):
        super().__init__(**options)
        self.proxy = proxy
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name in ("setup_input_hook", "setup_output_hook"):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx, hook_method_name, args):
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )