Spaces:
Running
Running
File size: 13,902 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 |
# mypy: ignore-errors
import functools
import inspect
from typing import Dict, List
import torch
from ...fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable
class DistributedVariable(VariableTracker):
    """
    Shared base class for variables that wrap torch.distributed objects
    (for example ProcessGroup or DeviceMesh).

    Construction checks that the torch.distributed package is available,
    calling ``unimplemented`` if it is not, and keeps the wrapped Python
    object on ``self.value``. Concrete subclasses add object-specific
    tracing logic on top of this.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        distributed_ok = DistributedVariable.is_available()
        if not distributed_ok:
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        """Return the Python type of the wrapped distributed object."""
        wrapped = self.value
        return type(wrapped)

    @staticmethod
    def is_available():
        """Report whether torch.distributed is available in this build."""
        return torch.distributed.is_available()
def is_from_local(value):
    """
    Return True only when ``value`` is the ``DTensor.from_local`` function itself.

    When torch.distributed is unavailable this returns False without
    attempting the DTensor import.
    """
    if DistributedVariable.is_available():
        from torch.distributed._tensor import DTensor

        return inspect.isfunction(value) and value is DTensor.from_local
    return False
def is_constant_pg_functions(value):
    """
    Return True if ``value`` is one of the distributed_c10d helpers that this
    module lists as constant process-group functions.

    Returns False when torch.distributed is unavailable or when ``value`` is
    not a plain Python function.
    """
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    if not inspect.isfunction(value):
        return False

    known_constant_helpers = (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    )
    return value in known_constant_helpers
class PlacementClassVariable(DistributedVariable):
    """Variable wrapping a Placement subclass (the class object, not an instance)."""

    @staticmethod
    def is_placement_type(value):
        """Return True if ``value`` is a class that subclasses Placement."""
        # torch.distributed is not always built, so check before importing from it.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        if type(value) is not type:
            return False
        return issubclass(value, Placement)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """
        Trace ``PlacementCls(*args, **kwargs)``.

        When the class keeps the default ``object.__new__`` and this variable
        has a source, a fresh instance is created eagerly and ``__init__`` is
        traced through PlacementVariable. Otherwise defer to the base class.
        """
        default_new = inspect.getattr_static(self.value, "__new__", None) in (
            object.__new__,
        )
        if not (default_new and self.source):
            return super().call_function(tx, args, kwargs)

        # Placement classes are meant to be immutable, so mutations on the
        # new instance are deliberately not tracked.
        instance = object.__new__(self.value)
        placement_var = PlacementVariable(instance)
        if inspect.getattr_static(self.value, "__init__", None):
            placement_var.call_method(tx, "__init__", args, kwargs)
        return placement_var
class PlacementVariable(DistributedVariable):
    """Variable wrapping a Placement instance (e.g. ``Shard(dim)``) as a Python constant."""

    @staticmethod
    def is_placement(value):
        """Return True if ``value`` is an instance of Placement."""
        # torch.distributed is not always built, so check before importing from it.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        # ``dim`` is read straight off the wrapped placement as a constant.
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Constant-fold the small set of Placement methods dynamo allows.

        ``__init__`` and ``__setattr__`` cover construction such as
        ``Shard(dim)``. Every listed method must:
          1. take constant arguments that do not need to be guarded on, and
          2. return a value that is constant with respect to those arguments.
        Any other method is deferred to the base class.
        """
        constant_fold_functions = (
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        )
        if name not in constant_fold_functions:
            return super().call_method(tx, name, args, kwargs)

        value_type = type(self.value)
        assert (
            inspect.getattr_static(value_type, "__getattr__", None) is None
        ), "no custom getattr allowed!"
        try:
            method = inspect.getattr_static(value_type, name)
        except AttributeError:
            # The placement type does not define this method. Previously this
            # left ``method = None`` and crashed with a TypeError when called.
            # Let the base class handle it instead.
            return super().call_method(tx, name, args, kwargs)

        # The default object.__init__ has no effect worth tracing.
        if method is object.__init__:
            return ConstantVariable.create(None)

        const_args = [x.as_python_constant() for x in args]
        const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        if name == "__setattr__":
            # Mutate the wrapped placement in place; the variable itself is the result.
            method(self.value, *const_args, **const_kwargs)
            return self
        constant_val = method(self.value, *const_args, **const_kwargs)
        return ConstantVariable.create(constant_val)
class DeviceMeshVariable(DistributedVariable):
    """Variable wrapping a DeviceMesh; its queries are folded into constants."""

    @staticmethod
    def is_device_mesh(value):
        """Return True if ``value`` is exactly a DeviceMesh (subclasses excluded)."""
        # torch.distributed is not always built, so check before importing from it.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        return super().var_getattr(tx, name)

    @staticmethod
    def _as_constants(args, kwargs):
        """Convert traced positional/keyword arguments to plain Python constants."""
        const_args = [x.as_python_constant() for x in args]
        const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        return const_args, const_kwargs

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Handle the supported DeviceMesh methods; anything else goes to the base class.

        Arguments are converted to constants only for methods that accept
        them. Unhandled calls therefore do not require constant arguments.
        """
        if name == "size":
            const_args, const_kwargs = self._as_constants(args, kwargs)
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            # Forward the caller's mesh_dim (positional or keyword). Previously
            # it was dropped, so e.g. ``mesh.get_group(1)`` asked for the
            # default group. The result is a ProcessGroup, so wrap it the same
            # way _get_or_create_default_group does.
            const_args, const_kwargs = self._as_constants(args, kwargs)
            return ProcessGroupVariable(
                self.value.get_group(*const_args, **const_kwargs)
            )
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)
class ProcessGroupVariable(DistributedVariable):
    """
    Variable for a ProcessGroup encountered while tracing.

    A ProcessGroup must not end up in the output graph. Yet user code often
    grabs one to query ``rank()`` / ``size()``, or passes it to
    distributed_c10d helpers that reduce it to plain values such as a rank
    list and a tag. A dedicated variable type makes those uses convenient
    and allows proper guarding.

    TODO: allow a ProcessGroupVariable to be passed to simple functions such
    as _expand_group without dynamo trying to create a proxy for it. It is
    not tensor-like, but dynamo assumes torch library functions receive
    tensor-like arguments that have proxies.

    TODO: decide whether unhandled cases should keep the inherited default
    behaviors or simply graph-break.
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Fold ``rank()`` / ``size()`` into constants; defer anything else to the base class."""
        if name in ("rank", "size"):
            result = getattr(self.value, name)()
            return variables.ConstantVariable.create(result)
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Expose ``group_name`` as a constant and ``rank`` / ``size`` as callables."""
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name not in ("rank", "size"):
            # TODO should this just raise unimplemented?
            return super().var_getattr(tx, name)

        def _bound_method(*args, **kwargs):
            return self.call_method(tx, name, args, kwargs)

        return variables.LambdaVariable(_bound_method)

    @staticmethod
    def is_process_group(value):
        """Return True if ``value`` is exactly a ProcessGroup or a FakeProcessGroup."""
        # torch.distributed is not always built, so check before importing from it.
        if not DistributedVariable.is_available():
            return False

        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))

    @staticmethod
    def get_global_pg_variable():
        """
        Build a ProcessGroupVariable for ``torch.distributed.group.WORLD``.

        An ID_MATCH guard is installed on that attribute path.
        """
        import torch.distributed as dist

        torch_src = GlobalSource(global_name="torch")
        dist_src = AttrSource(base=torch_src, member="distributed", get_static=False)
        group_src = AttrSource(base=dist_src, member="group", get_static=False)
        world_src = AttrSource(base=group_src, member="WORLD", get_static=False)
        install_guard(world_src.make_guard(GuardBuilder.ID_MATCH))
        return ProcessGroupVariable(dist.group.WORLD, source=world_src)
class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.

    The variable is backed by an FX proxy for a graph node that builds the
    BackwardHook at runtime. It also keeps the traced module and user hook
    variables it was created from.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ) -> "BackwardHookVariable":
        """
        Emit a graph node that constructs a BackwardHook and wrap it in a variable.

        Calls ``unimplemented`` unless compiled autograd is enabled.
        ``module``, ``user_pre_hooks`` and ``user_hooks`` are each registered
        with ``tx.output.add_backward_state_hook``. The node then calls
        ``_in_graph_bw_hooks`` with the backward-state proxy returned by
        registering the module.
        """
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (which
            don't survive AotAutograd), we install hooks that will call
            trace_wrapped in the backward pass that CompiledAutograd
            can turn into actual hook calls.
            """
            # NOTE: user_hooks_name, user_pre_hooks_name and module_name are
            # assigned further down in `create`. This closure only reads them
            # when the graph node runs, after those assignments have happened.
            return torch.utils.hooks.BackwardHook(
                None,
                # Hooks built from the name registered for `user_hooks`.
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                # Pre-hooks built from the name registered for `user_pre_hooks`.
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        # Register the module and both hook collections. Only the module
        # registration's backward-state proxy is used as the node input.
        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module)
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        # Use an empty BackwardHook as the node's example value.
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ):
        super().__init__(**options)
        # FX proxy for the node that constructs the BackwardHook.
        self.proxy = proxy
        # The traced module and user hook variables passed to `create`.
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace ``setup_input_hook`` / ``setup_output_hook``; defer anything else to the base class."""
        if name in ("setup_input_hook", "setup_output_hook"):
            # Unpacked into _setup_hook, which takes exactly one argument variable.
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx, hook_method_name, args):
        """
        Emit a ``call_method`` node invoking ``hook_method_name`` on the hook proxy.

        ``args`` is a single variable whose proxy is passed as the method argument.
        The new node is wrapped with ``wrap_fx_proxy`` and returned.
        """
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )
|