import sys | |
import warnings | |
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union | |
import torch | |
import torch.distributed as dist | |
import torch.distributed.distributed_c10d as c10d | |
from torch._custom_ops import impl_abstract | |
from torch.distributed.device_mesh import DeviceMesh | |
from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode | |
from . import _functional_collectives_impl as fun_col_impl | |
from ._functional_collectives_impl import ( # noqa: F401 | |
_register_tensor_wrapper, | |
native_funcol_enabled, | |
) | |
try: | |
from torch.utils._cxx_pytree import tree_map_only | |
except ImportError: | |
from torch.utils._pytree import tree_map_only # type: ignore[no-redef] | |
if torch._running_with_deploy(): | |
def is_torchdynamo_compiling(): | |
"""Can't import torchdynamo in torchdeploy builds currently.""" | |
return False | |
else: | |
try: | |
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling | |
except Exception: | |
warnings.warn( | |
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" | |
) | |
def is_torchdynamo_compiling(): | |
return False | |
""" | |
New traceable, functional collectives. | |
RFC: https://github.com/pytorch/pytorch/issues/93173 | |
compiler: trace these ops with plain-old-data schemas, then choose how to lower them. | |
eager: execute these 'functional' ops, which in eager mode return AsyncCollectiveTensor subclasses
that automatically call .wait() on the underlying/hidden async 'work' object only when fed to
a downstream op.
Issues: | |
* Where should these ops live? Placing them in the existing torch.distributed files broke `import torch`.
* Proper eager support requires inplace ops. We should explore exposing them as an option in the API.
""" | |
""" | |
Functional collectives are asynchronous only and we perform implicit stream synchronization | |
on behalf of the user. | |
We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us observe the
first usage of the tensor and insert the cross-stream sync at the right place.
The above are the easy bits; the hard one is how we match the Work object returned by
c10d with the tensor that AsyncCollectiveTensor wraps. We allocate the tensor inside the
collective op implementation (see the ``clone()`` call in ``_all_reduce``) and then it is handled
by the dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).
This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
allocated by our implementation, which makes pairing the AsyncCollectiveTensor with the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.
Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's | |
identity is not stable across dispatch, the op caller would end up with a different Tensor | |
instance that would not match any in the dictionary. | |
With Tensor identity out of the question, we decided to use the tensor's data pointer, which
should be stable across all the Tensor changes done during dispatch.
We keep a dictionary of tensor::data_ptr -> Work that we insert into right after we call into c10d.
We consult this dictionary when AsyncCollectiveTensor needs to invoke Work::wait().
Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
can clean up stale entries in the dictionary. | |
To eliminate the possibility of races we have a global version counter that is used by the finalizer. | |
As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo) | |
""" | |
""" | |
Functional collectives can accept any of these types to describe the ranks participating in collectives. | |
The different types will be desugared into a canonical format.
""" | |
RANK_TYPES = Union[ | |
List[int], | |
List[List[int]], | |
dist.ProcessGroup, | |
DeviceMesh, | |
Tuple["dist._tensor.DeviceMesh", int], | |
str, | |
] | |
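# All of the following are accepted ways to describe the participating ranks
# (illustrative; assumes the process group / DeviceMesh objects already exist):
#
#   all_reduce(t, "sum", [0, 1, 2, 3])          # flat rank list (SPMD)
#   all_reduce(t, "sum", [[0, 1], [2, 3]])      # 2D rank mesh (MPMD)
#   all_reduce(t, "sum", dist.group.WORLD)      # ProcessGroup
#   all_reduce(t, "sum", device_mesh)           # 1D DeviceMesh
#   all_reduce(t, "sum", (device_mesh, 0))      # one dimension of a DeviceMesh
#   all_reduce(t, "sum", "group_name")          # group name (native funcol only)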
""" | |
User facing APIs for functional collectives | |
------------------------------------------- | |
These APIs are called by user code and are expected to work both in eager execution and under
compilation, but the two modes are implemented quite differently underneath.
Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via the wait_tensor() op)
just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
and cannot yet correctly trace the AsyncCollectiveTensor wrapper class. In the future, these paths may be unified
if sufficient subclass support is added in dynamo.
Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern. | |
Here's how it works under torch.compile/dynamo: | |
all_reduce(...) | |
|--> _expand_group(...) - desugars processgroup into canonical/traceable format | |
|--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper | |
|--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed | |
And under eager execution: | |
all_reduce(...) | |
|--> _expand_group(...) - same as above, but less critical for eager | |
|--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace | |
|--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor, | |
which issues wait_tensor() at the time of first use | |
""" | |
def wait_tensor(tensor): | |
""" | |
Wait on a tensor returned by the collectives ops. | |
Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA. | |
""" | |
if native_funcol_enabled(): | |
return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
else: | |
return torch.ops.c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""): | |
""" | |
Broadcasts the tensor to all processes in the given process group. | |
Args: | |
src (int): Source rank | |
group (ProcessGroup or List[int]): The process group to work on. | |
tag (str, optional): A unique identifier for the collective. Default: empty string | |
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
tensor = torch.ops._c10d_functional.broadcast(self, src, group_name) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor = torch.ops.c10d_functional.broadcast( | |
self, src, tag, rankset, group_size | |
) | |
return _maybe_wrap_tensor(tensor) | |
def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""): | |
""" | |
Reduces the tensor data across all machines in such a way that all get | |
the final result. | |
The input tensor is left unmodified. | |
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
tensor = torch.ops._c10d_functional.all_reduce( | |
self, reduceOp.lower(), group_name | |
) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor = torch.ops.c10d_functional.all_reduce( # type: ignore[attr-defined] | |
self, | |
reduceOp, | |
tag, | |
rankset, | |
group_size, | |
) | |
return _maybe_wrap_tensor(tensor) | |
def all_gather_tensor( | |
self: torch.Tensor, | |
gather_dim: int, | |
group: RANK_TYPES, | |
tag: str = "", | |
): | |
""" | |
Gather tensor data from all machines and concatenate over ``gather_dim``.
Note that the underlying op only gathers along dim 0; other values of ``gather_dim`` are handled
by chunking and re-concatenating the result after the collective (see below).
The input tensor is left unmodified. | |
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
assert self.is_contiguous() | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
group_size = c10d._get_group_size_by_name(group_name) | |
tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
self, group_size, group_name | |
) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor = torch.ops.c10d_functional.all_gather_into_tensor( # type: ignore[attr-defined] | |
self, | |
tag, | |
rankset, | |
group_size, | |
) | |
res = _maybe_wrap_tensor(tensor) | |
# TODO this should be done inside AsyncCollectiveTensor to delay the wait() call | |
if gather_dim != 0: | |
# torch.cat accesses the data, so we need to wait here anyway; doing the wait first
# and then chunk + cat avoids going through ACT dispatching logic again
if isinstance(res, AsyncCollectiveTensor): | |
res = res.wait() # type: ignore[attr-defined] | |
res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) | |
return res | |
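# Shape semantics (illustrative, group_size = 4): a local shard of shape [8, 16] gathers
# to [32, 16] with gather_dim=0, and to [8, 64] with gather_dim=1 via the chunk + cat above.
#
#   shard = torch.randn(8, 16, device="cuda")
#   out0 = all_gather_tensor(shard, 0, group)   # shape [32, 16], wait deferred
#   out1 = all_gather_tensor(shard, 1, group)   # shape [8, 64], waits eagerly (see above)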
def reduce_scatter_tensor( | |
self: torch.Tensor, | |
reduceOp: str, | |
scatter_dim: int, | |
group: RANK_TYPES, | |
tag: str = "", | |
): | |
""" | |
Reduces the tensor data across all machines in such a way that all get | |
the final result, then scatter the results to corresponding ranks. | |
The input tensor is left unmodified. | |
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
group_size = c10d._get_group_size_by_name(group_name) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
assert (
self.size(scatter_dim) % group_size == 0
), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
if scatter_dim != 0: | |
tensor_list = torch.chunk(self, group_size, dim=scatter_dim) | |
self = torch.cat(tensor_list) | |
if native_funcol_enabled(): | |
tensor = torch.ops._c10d_functional.reduce_scatter_tensor( | |
self, | |
reduceOp.lower(), | |
group_size, | |
group_name, # type: ignore[possibly-undefined] | |
) | |
else: | |
tensor = torch.ops.c10d_functional.reduce_scatter_tensor( # type: ignore[attr-defined] | |
self, | |
reduceOp, | |
tag, | |
rankset, # type: ignore[possibly-undefined] | |
group_size, | |
) | |
res = _maybe_wrap_tensor(tensor) | |
return res | |
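# Shape semantics (illustrative, group_size = 4): an input of shape [32, 16] reduce-scatters
# to [8, 16] per rank with scatter_dim=0; scatter_dim=1 first re-lays the input via the
# chunk + cat above and yields [32, 4] per rank.
#
#   x = torch.randn(32, 16, device="cuda")
#   y = reduce_scatter_tensor(x, "sum", 0, group)   # shape [8, 16]
#   z = reduce_scatter_tensor(x, "sum", 1, group)   # shape [32, 4]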
def all_reduce_coalesced( | |
self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" | |
) -> List[torch.Tensor]: | |
""" | |
Reduces a list of tensors across all machines in such a way that all get | |
the final result. | |
All tensors in the input list are left unmodified.
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] | |
self, | |
reduceOp.lower(), | |
group_name, | |
) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor_list = torch.ops.c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] | |
self, | |
reduceOp, | |
tag, | |
rankset, | |
group_size, | |
) | |
return list(map(_maybe_wrap_tensor, tensor_list)) | |
def all_gather_into_tensor_coalesced( | |
self: List[torch.Tensor], group: RANK_TYPES, tag: str = "" | |
) -> List[torch.Tensor]: | |
""" | |
Gather a list of tensors from all machines.
Note that it currently only supports gather_dim = 0.
The input tensors are left unmodified.
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
group_size = c10d._get_group_size_by_name(group_name) | |
tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] | |
self, | |
group_size, | |
group_name, | |
) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] | |
self, | |
tag, | |
rankset, | |
group_size, | |
) | |
return list(map(_maybe_wrap_tensor, tensor_list)) | |
def reduce_scatter_tensor_coalesced( | |
inputs: List[torch.Tensor], | |
reduceOp: str, | |
scatter_dim: List[int], | |
group: RANK_TYPES, | |
tag: str = "", | |
) -> List[torch.Tensor]: | |
""" | |
Reduces a list of tensors across all machines in such a way that all get | |
the final result, then scatter the results to corresponding ranks. | |
The input tensors are left unmodified. | |
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
group_size = c10d._get_group_size_by_name(group_name) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
assert len(scatter_dim) == len(inputs) | |
for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): | |
assert ( | |
tensor.size(dim) % group_size == 0 | |
), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" | |
if dim != 0: | |
tensor_list = torch.chunk(tensor, group_size, dim=dim) | |
inputs[idx] = torch.cat(tensor_list) | |
if native_funcol_enabled(): | |
tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] | |
inputs, | |
reduceOp.lower(), | |
group_size, | |
group_name, # type: ignore[possibly-undefined] | |
) | |
else: | |
tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] | |
inputs, | |
reduceOp, | |
tag, | |
rankset, # type: ignore[possibly-undefined] | |
group_size, | |
) | |
return list(map(_maybe_wrap_tensor, tensor_list)) | |
# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. | |
# Today, this maps 1:1 with "aten ops that are views". | |
def _is_view_op(tgt): | |
assert isinstance(tgt, torch._ops.OpOverload) | |
schema = tgt._schema | |
if len(schema.arguments) > 0: | |
first_arg = schema.arguments[0] | |
# check if op is a view | |
return first_arg.alias_info is not None and not first_arg.alias_info.is_write | |
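# Illustrative schemas: aten.view.default is "view(Tensor(a) self, ...) -> Tensor(a)",
# so the first argument carries a non-write alias and this returns True; aten.add.Tensor
# has no alias_info on self, and in-place ops like aten.add_.Tensor carry a write alias
# ("Tensor(a!) self"), so both return False.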
def all_to_all_single( | |
self: torch.Tensor, | |
output_split_sizes: Optional[List[int]], | |
input_split_sizes: Optional[List[int]], | |
group: RANK_TYPES, | |
tag: str = "", | |
) -> torch.Tensor: | |
""" | |
Each process splits the input tensor and scatters the splits to all processes in the group.
It then concatenates the received tensors from all processes in the group and returns a
single output tensor.
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
:: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
""" | |
if output_split_sizes is not None: | |
assert all( | |
isinstance(size, (int, torch.SymInt)) for size in output_split_sizes | |
), output_split_sizes | |
if input_split_sizes is not None: | |
assert all( | |
isinstance(size, (int, torch.SymInt)) for size in input_split_sizes | |
), input_split_sizes | |
if native_funcol_enabled(): | |
group_name = _resolve_group_name(group, tag) | |
group_size = c10d._get_group_size_by_name(group_name) | |
if output_split_sizes is None or input_split_sizes is None: | |
assert output_split_sizes is None and input_split_sizes is None, ( | |
"output_split_sizes and input_split_sizes must either be " | |
"specified together or both set to None" | |
) | |
output_split_sizes = [self.shape[0] // group_size] * group_size | |
input_split_sizes = output_split_sizes | |
tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] | |
self, | |
output_split_sizes, | |
input_split_sizes, | |
group_name, | |
) | |
else: | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor = torch.ops.c10d_functional.all_to_all_single( # type: ignore[attr-defined] | |
self, | |
output_split_sizes, | |
input_split_sizes, | |
tag, | |
rankset, | |
group_size, | |
) | |
return _maybe_wrap_tensor(tensor) | |
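# Split-size semantics (illustrative, group_size = 2): rank r sends input_split_sizes[d]
# rows of its input to rank d and receives output_split_sizes[s] rows from rank s, so the
# output's dim 0 equals sum(output_split_sizes).
#
#   # on rank 0: keep 1 row locally, send 3 rows to rank 1, receive 2 rows from rank 1
#   out = all_to_all_single(x, output_split_sizes=[1, 2], input_split_sizes=[1, 3], group=group)
#   # out.shape[0] == 3 (assuming rank 1 passes matching split sizes)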
def permute_tensor( | |
self: torch.Tensor, | |
src_dst: List[int], | |
group: RANK_TYPES, | |
tag: str = "", | |
) -> torch.Tensor: | |
""" | |
Permutes the tensor across ranks according to the given source/destination pairs. `src_dst` should
be defined such that src_dst[m] == n means rank m sends its tensor to rank n.
Group can be one of: | |
List[int]: ranks participating in the collective. | |
List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
ProcessGroup: Will perform a collective using the ranks and tag of the PG.
DeviceMesh: Do an SPMD collective over all ranks of the mesh
(DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
""" | |
t, rankset, group_size = _expand_group(group, tag) | |
local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size) | |
output_split_sizes = [0] * group_size | |
input_split_sizes = [0] * group_size | |
for src, dst in enumerate(src_dst): | |
if src == dist.get_rank(local_pg): | |
input_split_sizes[dst] = self.numel() | |
if dst == dist.get_rank(local_pg): | |
output_split_sizes[src] = self.numel() | |
return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag) | |
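# Illustrative: with 4 ranks, src_dst = [1, 2, 3, 0] means rank 0 sends to 1, 1 to 2,
# 2 to 3, and 3 back to 0, so each rank ends up with the tensor of the previous rank:
#
#   shifted = permute_tensor(t, [1, 2, 3, 0], group)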
class AsyncCollectiveTensor(torch.Tensor): | |
r""" | |
A Tensor wrapper subclass that is used to trigger a call to wait | |
prior to first use of the underlying tensor. | |
Use it inside functional collective pytorch wrappers like the following: | |
def functional_collective(self, group, tag): | |
tag, rankset, group_size = _expand_group(group, tag) | |
tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) | |
return _maybe_wrap_tensor(tensor) | |
""" | |
elem: torch.Tensor | |
completed: bool | |
__slots__ = ["elem", "completed"] | |
def __new__(cls, elem: torch.Tensor): | |
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] | |
cls, | |
elem.size(), | |
strides=elem.stride(), | |
storage_offset=elem.storage_offset(), | |
dtype=elem.dtype, | |
layout=elem.layout, | |
device=elem.device, | |
requires_grad=False, | |
) | |
r.elem = elem | |
r.completed = False | |
return r | |
def __tensor_flatten__(self): | |
return ["elem"], None | |
def tolist(self): | |
self.trigger_wait() | |
return self.elem.tolist() | |
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): | |
assert meta is None | |
elem = inner_tensors["elem"] | |
return AsyncCollectiveTensor(elem) | |
def __repr__(self): | |
self.trigger_wait() | |
return f"AsyncCollectiveTensor({self.elem})" | |
def trigger_wait(self): | |
if not self.completed: | |
wait_tensor(self.elem) | |
self.completed = True | |
return self.elem | |
def wait(self) -> torch.Tensor: | |
wait_tensor(self.elem) | |
return self.elem | |
def _get_acs_underlying_tensor(self): | |
"""This method enables _functional_collectives_impl to test if a tensor is an ACS""" | |
return self.elem | |
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | |
if func == torch.ops.aten.view.default: | |
# Fast-path aten.view, since a lot of view-related ops eventually go through aten.view;
# this avoids the pytree slowdown
res = func(args[0].elem, args[1]) | |
wrapper_res = AsyncCollectiveTensor(res) | |
_register_tensor_wrapper(wrapper_res) | |
return wrapper_res | |
is_view_op = _is_view_op(func) | |
def unwrap(e: AsyncCollectiveTensor): | |
# wait_tensor is idempotent and will do stream sync only once
if not is_view_op: | |
e.trigger_wait() | |
return e.elem | |
def wrap(e: torch.Tensor): | |
# only view-op outputs get re-wrapped here; their wait stays deferred to first real use
assert not isinstance(e, AsyncCollectiveTensor) | |
res = AsyncCollectiveTensor(e) | |
_register_tensor_wrapper(res) | |
return res | |
unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args) | |
unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs) | |
# For non-view ops the inputs were already waited on, so the result doesn't need to be wrapped.
out = func(*unwrapped_args, **unwrapped_kwargs) | |
# View ops don't require a sync, so we re-wrap their outputs to keep deferring the wait.
if is_view_op: | |
out = tree_map_only(torch.Tensor, wrap, out) | |
return out | |
def numpy(self): | |
return self.wait().numpy() | |
""" | |
Utils and infrastructure for tracing support | |
""" | |
def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]: | |
""" | |
_expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. | |
By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside | |
torchdynamo and can still interoperate with processgroup objects or other untraceable forms. | |
""" | |
# We had to define this hack _inside_ _expand_group to avoid a dynamo graph break
# ("torch.* op returned non-Tensor int"), caused by the cast_* functions being
# treated as torch.* ops (iiuc)
if TYPE_CHECKING: | |
def cast_listlistint(x): | |
return cast(List[List[int]], x) | |
def cast_listint(x): | |
return cast(List[int], x) | |
else: | |
# fake cast op for use at runtime since dynamo doesn't support real cast | |
# also, dynamo didn't like encountering 'typing' objects:
# NotImplementedError: argument of type: <class 'typing._GenericAlias'> | |
def cast_listlistint(x): | |
return x | |
def cast_listint(x): | |
return x | |
rankset: List[int] | |
if isinstance(group, list): | |
if isinstance(group[0], list): | |
nested_list = cast_listlistint(group) | |
rankset = [] | |
group_size = -1 | |
for rs in nested_list: | |
rankset.extend(rs) | |
if group_size != -1 and group_size != len(rs): | |
raise ValueError( | |
f"group sizes must be identical found {group_size} and {len(rs)}" | |
) | |
group_size = len(rs) | |
else: | |
rankset = cast_listint(group) | |
group_size = len(rankset) | |
elif isinstance(group, dist.ProcessGroup): | |
rankset = dist.get_process_group_ranks(group) | |
group_size = len(rankset) | |
tag = tag or c10d._get_group_tag(group) | |
elif isinstance(group, DeviceMesh): | |
assert ( | |
group.ndim == 1 | |
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" | |
# TODO: it should run collective in the whole mesh instead of dim 0 | |
tag, rankset, _ = group._dim_group_infos[0] | |
group_size = len(rankset) | |
elif isinstance(group, tuple): | |
if ( | |
len(group) == 2 | |
and isinstance(group[0], DeviceMesh) | |
and isinstance(group[1], int) | |
): | |
dmesh = group[0] | |
dim = group[1] | |
tag, rankset, _ = dmesh._dim_group_infos[dim] | |
group_size = len(rankset) | |
else: | |
raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") | |
else: | |
raise ValueError( | |
"Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)." | |
) | |
return (tag, rankset, group_size) | |
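# Illustrative desugaring (tag/rank values are placeholders and depend on how the
# groups were created):
#
#   _expand_group([0, 1, 2, 3])      -> ("",      [0, 1, 2, 3], 4)
#   _expand_group([[0, 1], [2, 3]])  -> ("",      [0, 1, 2, 3], 2)   # group_size per sub-list
#   _expand_group(pg)                -> (pg_tag,  pg_ranks,     len(pg_ranks))
#   _expand_group((mesh, 1))         -> (dim_tag, dim_ranks,    len(dim_ranks))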
def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: | |
""" | |
Given group in RANK_TYPES, return the group name. | |
""" | |
# `tag` will be deprecated. See details in: | |
# https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 | |
if isinstance(group, dist.ProcessGroup): | |
return group.group_name | |
elif isinstance(group, str): | |
return group | |
elif isinstance(group, DeviceMesh): | |
assert ( | |
group.ndim == 1 | |
), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" | |
return group._dim_group_infos[0][2] | |
elif isinstance(group, tuple): | |
if ( | |
len(group) == 2 | |
and isinstance(group[0], DeviceMesh) | |
and isinstance(group[1], int) | |
): | |
dmesh = group[0] | |
dim = group[1] | |
return dmesh._dim_group_infos[dim][2] | |
else: | |
raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") | |
elif isinstance(group, list): | |
if not is_torchdynamo_compiling(): | |
warnings.warn( | |
"The combination of ranks + tag as process group " | |
"identifier has been deprecated. Please switch to " | |
"using ProcessGroup, DeviceMesh, or group name instead." | |
) | |
return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) | |
else: | |
raise ValueError(f"Unsupported group type: {type(group)}, {group}") | |
def _are_we_tracing() -> bool: | |
if is_torchdynamo_compiling(): | |
return True | |
# If functionalization is turned on, we are almost definitely compiling/tracing. | |
# (In particular, AOTAutograd traces a model once with functionalization on | |
# but proxy tracing turned off, so this is how we detect it).
if ( | |
torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) | |
is not None | |
): | |
return True | |
mode = get_innermost_proxy_mode() | |
if mode is None: | |
return False | |
return mode.tracer is not None | |
def _maybe_wrap_tensor(self) -> torch.Tensor: | |
if _are_we_tracing(): | |
return wait_tensor(self) | |
res = AsyncCollectiveTensor(self) | |
_register_tensor_wrapper(res) | |
return cast(torch.Tensor, res) | |
def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): | |
def mk_out_tensor(shard): | |
out_size = list(shard.size()) | |
out_size[0] *= group_size | |
out_tensor = shard.new_empty(out_size) | |
return out_tensor | |
return [mk_out_tensor(t) for t in self] | |
# We now register meta kernels to deal with tracing | |
def _broadcast_meta(self, *args): | |
return torch.empty_like(self) | |
def _all_reduce_meta(self, *args): | |
return torch.empty_like(self) | |
def _wait_tensor_meta(self, *args): | |
return torch.empty_like(self) | |
def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): | |
out_size = list(shard.size()) | |
out_size[0] *= group_size | |
return shard.new_empty(out_size) | |
def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): | |
out_size = list(input.size()) | |
out_size[0] //= group_size | |
return input.new_empty(out_size) | |
def _all_reduce_coalesced_meta(self, *args): | |
return [torch.empty_like(t) for t in self] | |
def _all_reduce__meta(inp, *args): | |
return inp | |
def _broadcast__meta(inp, *args): | |
return inp | |
def _all_reduce_coalesced__meta(inputs, *args): | |
return inputs | |
def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size): | |
def mk_out_tensor(input): | |
out_size = list(input.size()) | |
out_size[0] //= group_size | |
out_tensor = input.new_empty(out_size) | |
return out_tensor | |
return [mk_out_tensor(t) for t in inputs] | |
# NB: We often say all_to_all has dynamic output size, but this is not | |
# technically true: instead, what typically happens is you manually | |
# communicate the output_split_sizes ahead of time (which is dynamic), | |
# but then you pass those sizes explicitly, and the all_to_all itself
# isn't dynamic, it just follows the specified output splits (see the illustrative
# pattern after _all_to_all_single_meta below).
def _all_to_all_single_meta( | |
input, output_split_sizes, input_split_sizes, *args, **kwargs | |
): | |
if output_split_sizes is None: | |
return input.new_empty(input.size()) | |
else: | |
for s in output_split_sizes: | |
torch._check_is_size(s) | |
out_size = list(input.size()) | |
out_size[0] = sum(output_split_sizes) | |
return input.new_empty(out_size) | |
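# Typical pattern referenced in the NB above (illustrative): the split sizes are themselves
# exchanged with a fixed-size all_to_all first, so the data all_to_all runs with explicit
# (possibly SymInt) sizes rather than a truly dynamic output shape.
#
#   send_sizes = torch.tensor(input_split_sizes, device=x.device)
#   recv_sizes = torch.empty_like(send_sizes)
#   dist.all_to_all_single(recv_sizes, send_sizes)          # exchange sizes (eager c10d op)
#   out = all_to_all_single(x, recv_sizes.tolist(), input_split_sizes, group)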
def _all_gather_into_tensor_native_meta(input, group_size, group_name): | |
shape = list(input.size()) | |
shape[0] *= group_size | |
return input.new_empty(shape) | |
def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): | |
return [ | |
_all_gather_into_tensor_native_meta(input, group_size, group_name) | |
for input in inputs | |
] | |
def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): | |
shape = list(inp.size()) | |
shape[0] //= group_size | |
return inp.new_empty(shape) | |
def _reduce_scatter_tensor_coalesced_native_meta( | |
inputs, reduce_op, group_size, group_name | |
): | |
return [ | |
_reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name) | |
for inp in inputs | |
] | |
def _register_ops(): | |
ops_defs = [ | |
"broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", | |
"all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", | |
"all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", | |
"wait_tensor(Tensor self) -> Tensor", | |
"all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", | |
"all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", | |
"reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", | |
"reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", | |
"all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 | |
] | |
my_module = sys.modules[__name__] | |
for op_def in ops_defs: | |
op_name = op_def[0 : op_def.index("(")] | |
backend_impl = getattr(fun_col_impl, f"_{op_name}") | |
meta_impl = getattr(my_module, f"_{op_name}_meta") | |
c10_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) | |
c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd") | |
impl_abstract(f"c10d_functional::{op_name}")(meta_impl) | |
if not torch._running_with_deploy(): | |
# Library MUST be defined at module scope or it doesn't work.
# Creating a "DEF" Library always crashes torch::deploy, so we create our Library
# instances here, guarded against running inside torch::deploy.
c10_lib = torch.library.Library("c10d_functional", "DEF") | |
c10_lib_impl = torch.library.Library("c10d_functional", "IMPL") | |
_register_ops() | |
_c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL") | |
_c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") | |
_c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") | |
_c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") | |
_c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") | |
_c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") | |
_c10_lib_impl.impl( | |
"all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta" | |
) | |
_c10_lib_impl.impl( | |
"all_gather_into_tensor_coalesced", | |
_all_gather_into_tensor_coalesced_native_meta, | |
"Meta", | |
) | |
_c10_lib_impl.impl( | |
"reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta" | |
) | |
_c10_lib_impl.impl( | |
"reduce_scatter_tensor_coalesced", | |
_reduce_scatter_tensor_coalesced_native_meta, | |
"Meta", | |
) | |
_c10_lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") | |
_c10_lib_impl.impl("broadcast", _broadcast_meta, "Meta") | |
_c10_lib_impl.impl("broadcast_", _broadcast__meta, "Meta") | |
else: | |
warnings.warn( | |
"PyTorch Distributed functional collectives do not work with torch::deploy." | |
) | |
""" | |
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into | |
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph. | |
We implement this by writing a decomposition and teaching dynamo how to associate it with the
corresponding op via the mapping dict below.
These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are remapping from.
""" | |
def all_gather_tensor_inplace( | |
output_tensor: torch.Tensor, | |
input_tensor: torch.Tensor, | |
group, # TODO add a type, | |
async_op: bool = False, | |
tag: str = "", | |
gather_dim: int = 0, | |
): | |
assert ( | |
not async_op | |
), "Can't remap async version of inplace op to functional collective" | |
return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) | |
def reduce_scatter_tensor_inplace( | |
output: torch.Tensor, | |
input: torch.Tensor, | |
op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok? | |
group=None, # TODO add a type | |
async_op: bool = False, | |
scatter_dim: int = 0, | |
tag: str = "", | |
): | |
assert ( | |
not async_op | |
), "Can't remap async version of inplace op to functional collective" | |
return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) | |
REDUCE_OP_TO_STR = { | |
dist.ReduceOp.SUM: "sum", | |
dist.ReduceOp.AVG: "avg", | |
dist.ReduceOp.PRODUCT: "product", | |
dist.ReduceOp.MIN: "min", | |
dist.ReduceOp.MAX: "max", | |
dist.ReduceOp.BAND: "band", | |
dist.ReduceOp.BOR: "bor", | |
dist.ReduceOp.BXOR: "bxor", | |
} | |
def all_reduce_inplace( | |
tensor: torch.Tensor, | |
op: str = "sum", | |
group=None, | |
async_op: bool = False, | |
tag: str = "", | |
): | |
assert ( | |
not async_op | |
), "Can't remap async version of inplace op to functional collective" | |
return tensor.copy_(all_reduce(tensor, op, group, tag)) | |
def all_to_all_inplace( | |
output: torch.Tensor, | |
input: torch.Tensor, | |
output_split_sizes=None, | |
input_split_sizes=None, | |
group=None, | |
async_op=False, | |
tag: str = "", | |
): | |
assert ( | |
not async_op | |
), "Can't remap async version of inplace op to functional collective" | |
return output.copy_( | |
all_to_all_single(input, output_split_sizes, input_split_sizes, group, tag) | |
) | |
def all_gather_inplace( | |
tensor_list: List[torch.Tensor], | |
tensor: torch.Tensor, | |
group=None, | |
async_op=False, | |
tag: str = "", | |
): | |
assert ( | |
not async_op | |
), "Can't remap async version of inplace op to functional collective" | |
assert all( | |
t.size(0) == tensor.size(0) for t in tensor_list | |
), "Remapping variable size all_gather is not yet supported" | |
output = all_gather_tensor(tensor, 0, group, tag) | |
# Use aten.slice instead of aten.split because the latter causes | |
# tensor.shape[0] to be unnecessarily baked in when it's a SymInt.
output_splits = [] | |
offset = 0 | |
for t in tensor_list: | |
output_splits.append(output[offset : offset + t.size(0)]) | |
offset += t.size(0) | |
for dst, src in zip(tensor_list, output_splits): | |
dst.copy_(src) | |
return tensor_list | |
from torch.distributed.distributed_c10d import ( | |
_all_gather_base as legacy_all_gather_base, | |
_reduce_scatter_base as legacy_reduce_scatter_base, | |
all_gather as legacy_all_gather, | |
all_gather_into_tensor as legacy_allgather, | |
all_reduce as legacy_allreduce, | |
all_to_all_single as legacy_all_to_all_single, | |
reduce_scatter_tensor as legacy_reducescatter, | |
) | |
# This dict maps the non-functional collectives that dynamo is allowed to remap
# to their functional replacements. Each replacement must accept the same args/kwargs 1:1 as its mapping.
traceable_collective_remaps = { | |
legacy_allgather: all_gather_tensor_inplace, | |
legacy_reducescatter: reduce_scatter_tensor_inplace, | |
legacy_allreduce: all_reduce_inplace, | |
legacy_all_to_all_single: all_to_all_inplace, | |
legacy_all_gather: all_gather_inplace, | |
legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, | |
legacy_all_gather_base: all_gather_tensor_inplace, | |
} | |