Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from torch.distributed.rpc import is_available | |
from mmengine.dist import is_main_process | |
from mmengine.utils import digit_version | |
from mmengine.utils.dl_utils import TORCH_VERSION | |
try: | |
from torch.distributed.optim import \ | |
ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer | |
except ImportError: | |
_ZeroRedundancyOptimizer = object | |
from .builder import OPTIMIZERS | |
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its
    states across ranks in the group as described by ZeRO_. The local optimizer
    instance in each rank is only responsible for updating approximately
    ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated locally,
    each rank will broadcast its parameters to all other peers to keep all
    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
    reduce per-rank peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
    of parameters at each rank. Each parameter belongs to a single rank and is
    not divided among ranks. The partition is arbitrary and might not match
    the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
        groups.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class. It is
            looked up as an attribute of :mod:`torch.optim`.
        **kwargs: extra keyword arguments, forwarded unchanged to the parent
            class constructor.

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs) -> None:
        """Validate the environment, resolve the optimizer class and
        initialize the parent optimizer.

        Raises:
            AssertionError: If the PyTorch version is older than 1.8, if
                ``torch.distributed.rpc`` is not available, or if ``params``
                contains non-tensor entries (param groups) while the PyTorch
                version is older than 1.12.
            AttributeError: If :mod:`torch.optim` has no attribute named
                ``optimizer_type``.
        """
        torch_version = digit_version(TORCH_VERSION)
        assert torch_version >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'

        # Avoid the generator becoming empty after the following check
        params = list(params)
        assert (
            all(isinstance(p, torch.Tensor) for p in params)
            or torch_version >= digit_version('1.12.0')), (
                'PyTorch ZeroRedundancyOptimizer started to support param '
                'groups since 1.12.0. Please update your pytorch version to '
                'enable this feature, or disable param groups by deleting '
                '`paramwise_cfg` field in config file.')

        # Resolve the optimizer class from its name, e.g. 'SGD' or 'AdamW'.
        optimizer_class = getattr(torch.optim, optimizer_type)
        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
        # Currently only `overlap_with_ddp=False` is supported. For more
        # details, please refer to the pytorch's official documentation.
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self) -> dict:
        """Consolidate `state_dict`s from ranks to save the `state_dict`.

        Returns:
            dict: The parent class ``state_dict()`` on the main process, and
            an empty dict on every other process.
        """
        # Called on every rank before the main-process check below.
        # NOTE(review): presumably a collective operation that all ranks must
        # enter -- confirm against the PyTorch documentation.
        self.consolidate_state_dict()
        return super().state_dict() if is_main_process() else {}