Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
import copy | |
from abc import ABC, abstractmethod | |
from typing import ( | |
Any, | |
Callable, | |
cast, | |
Dict, | |
Generator, | |
Iterable, | |
Optional, | |
Sequence, | |
Set, | |
Tuple, | |
Type, | |
Union, | |
) | |
import torch.nn as nn | |
# Names exported by a star-import of this module.
__all__ = [
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    "enable_wrap",
    "wrap",
    "CustomPolicy",
    "ModuleWrapPolicy",
]
# NOTE: We intentionally keep this function simple and isolate the complexity | |
# to `fn` to enable using this function generically. We may move this to a | |
# non-FSDP-specific folder and/or make it public in the future. | |
def _post_order_apply(
    root_module: nn.Module,
    fn: Callable[[nn.Module], Optional[nn.Module]],
):
    """
    Walk the module tree rooted at ``root_module`` in post-order and call
    ``fn`` on each module exactly once (children before their parent).

    When ``fn`` returns an :class:`nn.Module`, that module is installed on the
    parent in place of the visited one. When ``fn`` returns ``None``, the
    visited module is left untouched. The root has no parent, so ``fn`` must
    return ``None`` for it.
    """
    # Shared submodules must only be processed the first time they are reached
    seen: Set[nn.Module] = {root_module}

    def _visit(
        module: nn.Module,
        module_name: str,
        parent_module: Optional[nn.Module],
    ):
        # Descend first so that ``fn`` sees children before their parent
        for name, child in module.named_children():
            if child in seen:
                continue
            seen.add(child)
            _visit(child, name, module)

        optional_module = fn(module)
        if optional_module is None:
            return

        # A replacement was requested: make sure we can actually install it
        assert isinstance(parent_module, nn.Module), (
            "Non-root modules should have their parent module set but got "
            f"{parent_module} for {module}"
        )
        assert module_name, (
            "Non-root modules should have their module name set but got "
            f"an empty module name for {module}"
        )
        assert isinstance(
            optional_module, nn.Module
        ), f"fn should return None or an nn.Module but got {optional_module}"
        setattr(parent_module, module_name, optional_module)

    _visit(root_module, "", None)
def _construct_wrap_fn(
    root_module: nn.Module,
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
    fsdp_fn: Callable,
) -> Callable[[nn.Module], Optional[nn.Module]]:
    """
    Build the per-module callback that :func:`_post_order_apply` expects.

    The returned callback applies ``fsdp_fn`` (with the kwargs recorded in
    ``target_module_to_kwargs``) to every target module except the root, and
    returns ``None`` for everything else.
    """

    def fn(module: nn.Module) -> Optional[nn.Module]:
        # The root is never wrapped here; for FSDP the caller takes care of it
        if module is root_module or module not in target_module_to_kwargs:
            return None
        return fsdp_fn(module, **target_module_to_kwargs[module])

    return fn
def _run_mixed_precision_override_policy(
    root_module: nn.Module,
    module_classes: Iterable[Type[nn.Module]],
    ignored_modules: Set[nn.Module],
    root_kwargs: Dict[str, Any],
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
):
    """
    Force ``mixed_precision=None`` for every non-ignored module in the tree of
    ``root_module`` that is an instance of one of ``module_classes``.

    Args:
        root_module (nn.Module): Root of the module tree to scan.
        module_classes (Iterable[Type[nn.Module]]): Classes whose instances
            get their ``"mixed_precision"`` kwarg overridden to ``None``.
        ignored_modules (Set[nn.Module]): Modules to skip entirely.
        root_kwargs (Dict[str, Any]): Kwargs inherited by matching modules
            that do not yet have an entry. This dict is not modified.
        target_module_to_kwargs (Dict[nn.Module, Dict[str, Any]]): Existing
            mapping produced by a policy; updated in place.

    Returns:
        The same ``target_module_to_kwargs`` object, after the override.
    """
    module_classes_tuple = tuple(set(module_classes))
    for module in root_module.modules():
        if module in ignored_modules:
            continue
        if not isinstance(module, module_classes_tuple):
            continue
        # This policy overrides any existing policy
        if module not in target_module_to_kwargs:
            # Only inherit from the root kwargs if not already specified.
            # Shallow copy so that the override below neither leaks into the
            # caller's ``root_kwargs`` nor couples the entries of different
            # modules to one shared dict.
            target_module_to_kwargs[module] = copy.copy(root_kwargs)
        target_module_to_kwargs[module]["mixed_precision"] = None
    return target_module_to_kwargs
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Recursive wrap policy that unconditionally answers ``True``.

    All positional and keyword arguments are accepted and ignored, so every
    submodule ends up wrapped by the wrapper class in
    :func:`_recursive_wrap`.
    """
    return True
class _Policy(ABC):
    """
    This defines an abstract base class that represents a policy for applying
    a module-level API.

    Concrete policies must implement :meth:`_run_policy`; the base class
    cannot be instantiated directly.
    """

    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        This should return a dict ``target_module_to_kwargs`` that maps from
        each target module to wrap to its kwargs.

        Args:
            root_module (nn.Module): Root of the module tree to inspect.
            ignored_modules (Set[nn.Module]): Modules that must not be
                selected as targets.
            root_kwargs (Dict[str, Any]): Kwargs given to the root, to be
                used as the basis for each target module's kwargs.
        """
        ...
def _module_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[nn.Module]],
) -> bool:
    """
    Auto wrap policy that selects modules by class.

    Every module that is an instance of a type in ``module_classes`` becomes
    its own FSDP instance. The root module given by ``module`` is always
    wrapped as an FSDP instance regardless. Because wrapping proceeds bottom
    up, each FSDP instance manages the parameters of its subtree minus those
    already managed by a child FSDP instance.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when a wrap/no-wrap decision is required for
            ``module``.
        nonwrapped_numel (int): Parameter numel not yet wrapped (unused by
            this policy).
        module_classes (Set[Type[nn.Module]]): Set of module classes that are
            wrapped as FSDP instances.

    Returns:
        ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
        if ``recurse=False``.
    """
    if not recurse:
        # Decision time: wrap only instances of the requested classes
        return isinstance(module, tuple(module_classes))
    # Still descending: never prune the traversal
    return True
class ModuleWrapPolicy(_Policy):
    """
    Policy that targets every module whose class is one of the given module
    classes; each targeted module receives the kwargs given to the root.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        self._module_classes = set(module_classes)
        # Cached once so that ``__repr__`` does not re-render the set
        self._module_classes_str = str(self._module_classes)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        wrap_types = tuple(self._module_classes)
        # Each target gets its own shallow copy of the root kwargs so that a
        # later change to one module's kwargs does not affect the others
        return {
            module: copy.copy(root_kwargs)
            for module in root_module.modules()
            if module not in ignored_modules and isinstance(module, wrap_types)
        }

    def __call__(self, module, recurse, *args, **kwargs):
        # ``nonwrapped_numel`` plays no role for this policy, hence the dummy
        return _module_wrap_policy(
            module,
            recurse,
            nonwrapped_numel=-1,
            module_classes=self._module_classes,
        )

    def __repr__(self) -> str:
        return f"{super().__repr__()}({self._module_classes_str})"
class CustomPolicy(_Policy):
    """
    Policy driven by a user function that maps an ``nn.Module`` to ``False``,
    ``True``, or a kwarg dictionary.

    - ``False`` or an empty dictionary: the API is not applied to the module.
    - ``True``: the API is applied to the module with the root's kwargs.
    - A non-empty dictionary: the API is applied to the module, and the
      dictionary's entries override the root's kwargs.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
        for module in root_module.modules():
            if module in ignored_modules:
                continue

            res = self._lambda_fn(module)
            if not isinstance(res, (dict, bool)):
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {res}"
                )
            if not res:
                # ``False`` and ``{}`` both mean "leave this module alone"
                continue

            # Start from a private copy of the root kwargs, then layer any
            # user-provided overrides on top
            overrides = res if isinstance(res, dict) else {}
            module_kwargs = copy.copy(root_kwargs)
            module_kwargs.update(overrides)
            target_module_to_kwargs[module] = module_kwargs
        return target_module_to_kwargs
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    Auto wrap policy that delegates the wrap decision to a user function.

    If ``lambda_fn(submodule)`` returns ``True``, the submodule will be
    wrapped as a ``wrapper_cls`` unit. The first three parameters are
    required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when a wrap/no-wrap decision is required for
            ``module``.
        nonwrapped_numel (int): Parameter numel not yet wrapped (unused by
            this policy).
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``,
            then this module will be wrapped.

    Returns:
        ``True`` if ``recurse`` is truthy; otherwise whatever
        ``lambda_fn(module)`` returns.
    """
    # Always keep descending; only consult the user function for decisions
    return True if recurse else lambda_fn(module)
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Class-based auto wrap policy for transformer models.

    This forwards to :func:`_module_wrap_policy` with
    ``transformer_layer_cls`` used as ``module_classes``. Note that shared
    parameters must be wrapped in the same FSDP instance, so this auto wrap
    policy can help wrap shared embeddings into the same FSDP instance for
    transformer models.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
def _wrap_module_cls_individually(
    module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
):
    """
    Policy that wraps a module when it is an instance of any type in
    ``module_classes``.

    Returns ``True`` while recursing (``recurse`` truthy) so the traversal is
    never pruned; otherwise returns the result of the ``isinstance`` check.
    Extra positional and keyword arguments are accepted and ignored.
    """
    return True if recurse else isinstance(module, tuple(module_classes))
def _or_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    policies,
) -> bool:
    """
    Combine several policies with a logical OR.

    Each policy in the iterable ``policies`` is called with ``module``,
    ``recurse`` and ``nonwrapped_numel`` as keyword arguments, in order, and
    evaluation stops at the first one that returns a truthy value.

    Returns:
        ``True`` if any policy accepted, ``False`` otherwise (including when
        ``policies`` is empty).
    """
    for policy in policies:
        if policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel):
            return True
    return False
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy that decides based on the not-yet-wrapped parameter
    count.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when a wrap/no-wrap decision is required for
            ``module``.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Size threshold, in units of numel, at or above
            which a module is considered large enough.
        force_leaf_modules (Set[Type[nn.Module]]): Module types to keep as
            leaves, i.e. their children will never be wrapped. ``None``
            selects ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types that are
            never wrapped themselves. ``None`` selects
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.

    Returns:
        Whether to recurse into ``module`` (``recurse=True``) or whether
        ``module`` should be wrapped (``recurse=False``).
    """
    if force_leaf_modules is None:
        force_leaf_modules = size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
    if exclude_wrap_modules is None:
        exclude_wrap_modules = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]

    # The name `min_num_params` is kept for BC, but the threshold is compared
    # against the non-wrapped *numel*
    is_large = nonwrapped_numel >= min_num_params

    # While recursing, the forced-leaf types stop the descent; when deciding,
    # the excluded types veto the wrap
    blocked_types = force_leaf_modules if recurse else exclude_wrap_modules
    return is_large and not isinstance(module, tuple(blocked_types))


# Defaults are attached to the function object so that they are easy to import
# and adjust alongside the policy itself.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager to wrap modules using a wrapper.

    Useful for when you'd like to apply the same configuration arguments to all
    child modules that you wrap. A particularly important use case is wrapping
    large layers so that they get sharded (in-place) during initialization, to
    avoid running out of system memory. Large layers can indicate that they
    should be sharded via the ``wrap`` annotation and this context manager can
    provide the exact configuration for these nested instances.

    Usage::

        with enable_wrap(wrapper_cls, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that `wrap` annotation will `wrap` modules with, such as
            `FullyShardedDataParallel`.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context
    """
    # This is a generator function, so it needs the ``contextmanager``
    # decorator above to be usable in a ``with`` statement.
    kwargs = {
        "wrapper_cls": wrapper_cls,
        **wrapper_kwargs,
    }
    with _ConfigAutoWrap(**kwargs):
        yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark ``module`` as one that should be wrapped.

    The wrapping only takes effect inside an :func:`enable_wrap` context
    manager; outside of one, ``module`` is returned unchanged. This allows a
    module to be initialized both with and without a wrapper without code
    change.

    The wrapper used is the ``wrapper_cls`` argument passed into
    ``enable_wrap``. Both ``enable_wrap`` and ``wrap`` can take in kwargs
    specifying how to construct the ``wrapper_cls`` instance. In the case of
    duplicate kwargs in ``enable_wrap`` and ``wrap``, the argument passed into
    ``wrap`` will be respected.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        # No active context: the annotation is a no-op
        return module

    assert _ConfigAutoWrap.wrapper_cls is not None
    # Context kwargs first, so that the per-call overrides win on conflicts
    merged_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **merged_kwargs)
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
    """
    Apply ``wrapper_cls`` to ``module`` with ``kwargs``.

    If ``module`` carries a ``_wrap_overrides`` attribute, its entries take
    precedence over ``kwargs`` for this module.
    """
    assert wrapper_cls is not None
    if not hasattr(module, "_wrap_overrides"):
        return wrapper_cls(module, **kwargs)
    # If module has a _wrap_overrides attribute, we force overriding the
    # FSDP config with these attributes for this module. Currently this
    # is only used to disable mixed precision for BatchNorm when
    # auto_wrapping.
    forced_kwargs = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]
    return wrapper_cls(module, **forced_kwargs)
def _recursive_wrap(
    module: nn.Module,
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    """
    Recursively wrap, with ``wrapper_cls``, the submodules of ``module`` for
    which ``auto_wrap_policy`` returns ``True``.

    Args:
        module (nn.Module): Module to recursively wrap.
        auto_wrap_policy (Callable): A callable representing a policy that
            determines which modules to recursively wrap with ``wrapper_cls``.
            It is called with the keyword arguments ``module``, ``recurse``
            and ``nonwrapped_numel``.
        wrapper_cls (Callable): Wrapper applied to the selected modules.
        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
            wrapping.
        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
            wrapping; these should be the parameters contained in the modules
            in ``ignored_modules``.
        only_wrap_children (bool): If ``True``, ``module`` itself is never
            wrapped by this call. The flag is not forwarded to the recursive
            calls on the children.
        **kwargs: Forwarded to the recursive calls and to the wrapper.

    Returns:
        (nn.Module, int):
            ``module`` after wrapping and the numel recursively wrapped.
    """
    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
    assert wrapper_cls is not None, "Must specify wrapper_cls"

    # Make sure no child is already wrapped.
    for _, submodule in module.named_modules():
        if submodule in ignored_modules:
            continue
        try:
            assert not isinstance(submodule, cast(type, wrapper_cls))
        except TypeError:
            # wrapper_cls is a function as opposed to a class type, just
            # bypass above check.
            pass

    # We count all params, assuming none of them are already wrapped.
    nonwrapped_numel = 0
    for param in module.parameters():
        if param not in ignored_params:
            nonwrapped_numel += param.numel()

    if not auto_wrap_policy(
        module=module, recurse=True, nonwrapped_numel=nonwrapped_numel
    ):
        # The policy does not want to descend into this subtree at all
        return module, 0

    # Iterate through the children, recursively wrap if necessary
    total_wrapped_numel = 0
    for name, child in module.named_children():
        if child in ignored_modules:
            continue
        wrapped_child, child_wrapped_numel = _recursive_wrap(
            module=child,
            auto_wrap_policy=auto_wrap_policy,
            wrapper_cls=wrapper_cls,
            ignored_modules=ignored_modules,
            ignored_params=ignored_params,
            **kwargs,
        )
        setattr(module, name, wrapped_child)
        # Keep track of how many parameters have been wrapped
        total_wrapped_numel += child_wrapped_numel

    if only_wrap_children:
        return module, total_wrapped_numel

    # Decide if we need to wrap the current module, based on the parameters
    # left over after the children have been wrapped
    remainder = nonwrapped_numel - total_wrapped_numel
    if auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=remainder):
        # Leaf node or final wrapping of the remainder both happen here.
        return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
    return module, total_wrapped_numel
class _ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active configuration lives in class attributes, so it is shared by
    the whole process and only one context may be active at a time.
    """

    in_autowrap_context: bool = False  # Context flag
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args

    def __init__(self, **kwargs: Dict[str, Any]):
        # Must contain "wrapper_cls"; this is checked when the context is
        # entered.
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(kwargs: Any) -> None:
        """
        Activate the autowrap context with the given configuration.

        Args:
            kwargs: Mapping that must contain ``"wrapper_cls"``; all other
                entries are stored as the wrapper's kwargs. The mapping
                itself is not modified.

        Raises:
            NotImplementedError: If an autowrap context is already active.
        """
        if _ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Validate before touching any shared state, so that a failed check
        # does not leave the context flag stuck on.
        assert (
            "wrapper_cls" in kwargs.keys()
        ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
        # Work on a copy so the caller's dict (``self.kwargs`` when entered
        # through ``__enter__``) keeps its "wrapper_cls" entry and the same
        # instance can be entered again later.
        remaining_kwargs = dict(kwargs)
        # Get and save the wrapper cls for the context.
        wrapper_cls = cast(Callable, remaining_kwargs.pop("wrapper_cls"))
        _ConfigAutoWrap.in_autowrap_context = True
        _ConfigAutoWrap.wrapper_cls = wrapper_cls
        # Save the rest.
        _ConfigAutoWrap.kwargs = remaining_kwargs

    @staticmethod
    def disable_autowrap_context() -> None:
        """Deactivate the autowrap context and clear the stored configuration."""
        _ConfigAutoWrap.in_autowrap_context = False
        _ConfigAutoWrap.wrapper_cls = None
        _ConfigAutoWrap.kwargs = {}

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()