from contextlib import contextmanager
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)

from .sharding_spec import (
    ShardingSpec,
    ChunkShardingSpec
)
from .sharding_plan import (
    ShardingPlan
)
from .sharder import Sharder


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which is used as the
    ground truth of the data which would be scattered as shards across the rest
    of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
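
    Example:
        A minimal usage sketch, assuming ``torch.distributed`` is already initialized
        with 2 ranks and one GPU per rank (the placements below are illustrative)::

            >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
            >>> spec = ChunkShardingSpec(
            ...     dim=0,
            ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
            ... )
            >>> st = _shard_tensor(torch.rand(16, 16), spec)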
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')

        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)

    return st


def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank which is used as the ground truth of the
    data which would be scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
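
    Example:
        A minimal sketch, assuming ``torch.distributed`` is already initialized with
        2 ranks and one GPU per rank (module and placements are illustrative)::

            >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
            >>> fc = torch.nn.Linear(16, 16).cuda()
            >>> spec = ChunkShardingSpec(
            ...     dim=0,
            ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
            ... )
            >>> shard_parameter(fc, "weight", spec)
            >>> # fc.weight is now an nn.Parameter wrapping a ShardedTensor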
    """
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace param with ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(st))


# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            'context manager')
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None


def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it just returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    else:
        return _CURRENT_PROCESS_GROUP


def _reshard_output(
        module: torch.nn.Module,
        resharding_spec: ShardingSpec) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with reshard API hooked.
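
    Example:
        An illustrative sketch; ``sharded_module`` is a hypothetical module whose
        forward pass returns a :class:`ShardedTensor`, and the placements assume a
        2-rank setup::

            >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
            >>> reshard_spec = ChunkShardingSpec(
            ...     dim=1,
            ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
            ... )
            >>> sharded_module = _reshard_output(sharded_module, reshard_spec)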
    """
    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            return output.reshard(resharding_spec)
        return output
    module.register_forward_hook(hook_func)
    return module


def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shard collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data
    parallel representation. In particular, it returns the local tensor for this Shard.
    If the size along the sharding dimension for the local tensor is 1, this dimension
    is removed from the final result. For example, a [4, 16] ShardedTensor across
    4 ranks is typically stored as a local tensor of size [16] on each rank rather
    than [1, 16].

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and
            whose local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with collection API hooked.
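
    Example:
        A minimal sketch; ``sharded_module`` and ``inp`` are hypothetical, and the
        module's forward pass is assumed to return a :class:`ShardedTensor`::

            >>> sharded_module = _collect_local_shard(sharded_module)
            >>> out = sharded_module(inp)  # `out` is now a plain local torch.Tensor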
    """
    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            local_tensor = output.local_tensor()
            # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
            sharding_spec = output._sharding_spec
            if isinstance(sharding_spec, ChunkShardingSpec) \
               and local_tensor.size(sharding_spec.dim) == 1:  # type: ignore[attr-defined, arg-type]
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor
    module.register_forward_hook(hook_func)
    return module


def shard_module(
    module: nn.Module,
    plan: ShardingPlan,
    src_rank=0,
    process_group=None
):
    """
    Shards a given module according to the provided sharding ``plan``. This method
    first shards all the parameters according to the given sharding ``plan``. Then, if
    ``output_plan`` and ``return_local_tensor`` are specified in the sharding ``plan``,
    it will tag the output of modules according to ``output_plan`` and convert the
    module's output back to data parallel according to ``return_local_tensor``.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specifies the param name to ShardingSpec mapping to
            apply to each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
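
    Example:
        A minimal sketch, assuming 2 initialized ranks with one GPU each; ``MyModule``
        and the parameter name ``"fc1.weight"`` are hypothetical stand-ins for your
        own model::

            >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
            >>> spec = ChunkShardingSpec(
            ...     dim=0,
            ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
            ... )
            >>> plan = ShardingPlan(
            ...     plan={"fc1.weight": spec},
            ...     return_local_tensor=["fc1"],
            ... )
            >>> model = MyModule().cuda()
            >>> shard_module(model, plan)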
    """
    # record Sharder paths for sanity check on the plan to ensure items in the plan
    # do not conflict with the submodule tree that the Sharder is working with
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # shard the parameters according to the ShardingPlan
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                                       f" but there's already a Sharder entry for module {sharder_path},"
                                       f" parameter sharding should not conflict with the submodule tree"
                                       f" that a Sharder is working with!")

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod,
                param_name,
                spec,
                src_rank=src_rank,
                process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # swap this submodule with the sharded module
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'")

    # reshard output if there's an entry in `output_plan` for this module
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'")

    # convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan; we will call `_collect_local_shard`
    # to collect the local tensor for the output of those modules
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)